This commit is contained in:
孟帅
2023-07-20 18:01:10 +08:00
parent 9113fc5297
commit 373d9627fb
492 changed files with 12170 additions and 6982 deletions

View File

@@ -52,10 +52,9 @@ func loadPermissions(ctx context.Context) {
Permissions string `json:"permissions"`
}
var (
rules [][]string
polices []*Policy
err error
superRoleKey = g.Cfg().MustGet(ctx, "hotgo.admin.superRoleKey")
rules [][]string
polices []*Policy
err error
)
err = g.Model("hg_admin_role r").
@@ -65,7 +64,7 @@ func loadPermissions(ctx context.Context) {
Where("r.status", consts.StatusEnabled).
Where("m.status", consts.StatusEnabled).
Where("m.permissions !=?", "").
Where("r.key !=?", superRoleKey.String()).
Where("r.key !=?", consts.SuperRoleKey).
Scan(&polices)
if err != nil {
g.Log().Fatalf(ctx, "loadPermissions Scan err:%v", err)

View File

@@ -25,9 +25,11 @@ var crons = &cronManager{
tasks: make(map[string]*TaskItem),
}
// cronStrategy 任务接口
type cronStrategy interface {
// Cron 定时任务接口
type Cron interface {
// GetName 获取任务名称
GetName() string
// Execute 执行一次任务
Execute(ctx context.Context)
}
@@ -51,7 +53,7 @@ func Logger() *glog.Logger {
}
// Register 注册任务
func Register(c cronStrategy) {
func Register(c Cron) {
crons.Lock()
defer crons.Unlock()
@@ -72,6 +74,9 @@ func StopALL() {
// StartALL 启动所有任务
func StartALL(sysCron []*entity.SysCron) (err error) {
crons.Lock()
defer crons.Unlock()
if len(crons.tasks) == 0 {
g.Log().Debug(gctx.GetInitCtx(), "no scheduled task is available.")
return
@@ -156,6 +161,9 @@ func Stop(sysCron *entity.SysCron) (err error) {
// Once 立即执行一次某个任务
func Once(ctx context.Context, sysCron *entity.SysCron) error {
crons.RLock()
defer crons.RUnlock()
for _, v := range crons.tasks {
if v.Name == sysCron.Name {
simple.SafeGo(ctx, func(ctx context.Context) {

View File

@@ -6,6 +6,9 @@
package hggen
import (
_ "hotgo/internal/library/hggen/internal/cmd/gendao"
_ "unsafe"
"context"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
@@ -23,6 +26,9 @@ import (
"sort"
)
//go:linkname doGenDaoForArray hotgo/internal/library/hggen/internal/cmd/gendao.doGenDaoForArray
func doGenDaoForArray(ctx context.Context, index int, in gendao.CGenDaoInput)
// Dao 生成数据库实体
func Dao(ctx context.Context) (err error) {
for _, v := range daoConfig {
@@ -31,7 +37,7 @@ func Dao(ctx context.Context) (err error) {
if err != nil {
return
}
gendao.DoGenDaoForArray(ctx, inp)
doGenDaoForArray(ctx, -1, inp)
}
return
}
@@ -52,11 +58,11 @@ func ServiceWithCfg(ctx context.Context, cfg ...genservice.CGenServiceInput) (er
}
// TableColumns 获取指定表生成字段列表
func TableColumns(ctx context.Context, in sysin.GenCodesColumnListInp) (fields []*sysin.GenCodesColumnListModel, err error) {
func TableColumns(ctx context.Context, in *sysin.GenCodesColumnListInp) (fields []*sysin.GenCodesColumnListModel, err error) {
return views.DoTableColumns(ctx, in, GetDaoConfig(in.Name))
}
func TableSelects(ctx context.Context, in sysin.GenCodesSelectsInp) (res *sysin.GenCodesSelectsModel, err error) {
func TableSelects(ctx context.Context, in *sysin.GenCodesSelectsInp) (res *sysin.GenCodesSelectsModel, err error) {
res = new(sysin.GenCodesSelectsModel)
res.GenType, err = GenTypeSelect(ctx)
if err != nil {
@@ -110,7 +116,7 @@ func TableSelects(ctx context.Context, in sysin.GenCodesSelectsInp) (res *sysin.
}
sort.Sort(res.FormRole)
dictMode, err := service.SysDictType().TreeSelect(ctx, sysin.DictTreeSelectInp{})
dictMode, err := service.SysDictType().TreeSelect(ctx, &sysin.DictTreeSelectInp{})
if err != nil {
return
}
@@ -183,7 +189,7 @@ func DbSelect(ctx context.Context) (res form.Selects) {
}
// Preview 生成预览
func Preview(ctx context.Context, in sysin.GenCodesPreviewInp) (res *sysin.GenCodesPreviewModel, err error) {
func Preview(ctx context.Context, in *sysin.GenCodesPreviewInp) (res *sysin.GenCodesPreviewModel, err error) {
genConfig, err := service.SysConfig().GetLoadGenerate(ctx)
if err != nil {
return nil, err
@@ -209,7 +215,7 @@ func Preview(ctx context.Context, in sysin.GenCodesPreviewInp) (res *sysin.GenCo
}
// Build 提交生成
func Build(ctx context.Context, in sysin.GenCodesBuildInp) (err error) {
func Build(ctx context.Context, in *sysin.GenCodesBuildInp) (err error) {
genConfig, err := service.SysConfig().GetLoadGenerate(ctx)
if err != nil {
return err
@@ -217,7 +223,7 @@ func Build(ctx context.Context, in sysin.GenCodesBuildInp) (err error) {
switch in.GenType {
case consts.GenCodesTypeCurd:
pin := sysin.GenCodesPreviewInp(in)
pin := &sysin.GenCodesPreviewInp{SysGenCodes: in.SysGenCodes}
return views.Curd.DoBuild(ctx, &views.CurdBuildInput{
PreviewIn: &views.CurdPreviewInput{
In: pin,

View File

@@ -1,9 +1,16 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"context"
"strings"
"github.com/gogf/gf/v2"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gcmd"
"github.com/gogf/gf/v2/util/gtag"
@@ -48,17 +55,21 @@ func (c cGF) Index(ctx context.Context, in cGFInput) (out *cGFOutput, err error)
_, err = Version.Index(ctx, cVersionInput{})
return
}
answer := "n"
// No argument or option, do installation checks.
if !service.Install.IsInstalled() {
if data, isInstalled := service.Install.IsInstalled(); !isInstalled {
mlog.Print("hi, it seams it's the first time you installing gf cli.")
s := gcmd.Scanf("do you want to install gf binary to your system? [y/n]: ")
if strings.EqualFold(s, "y") {
if err = service.Install.Run(ctx); err != nil {
return
}
gcmd.Scan("press `Enter` to exit...")
answer = gcmd.Scanf("do you want to install gf(%s) binary to your system? [y/n]: ", gf.VERSION)
} else if !data.IsSelf {
mlog.Print("hi, you have installed gf cli.")
answer = gcmd.Scanf("do you want to install gf(%s) binary to your system? [y/n]: ", gf.VERSION)
}
if strings.EqualFold(answer, "y") {
if err = service.Install.Run(ctx); err != nil {
return
}
gcmd.Scan("press `Enter` to exit...")
return
}
// Print help content.
gcmd.CommandFromCtx(ctx).Print()

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
@@ -123,15 +129,21 @@ type cBuildInput struct {
VarMap g.Map `short:"r" name:"varMap" brief:"custom built embedded variable into binary"`
PackSrc string `short:"ps" name:"packSrc" brief:"pack one or more folders into one go file before building"`
PackDst string `short:"pd" name:"packDst" brief:"temporary go file path for pack, this go file will be automatically removed after built" d:"internal/packed/build_pack_data.go"`
ExitWhenError bool `short:"ew" name:"exitWhenError" brief:"exit building when any error occurs, default is false" orphan:"true"`
ExitWhenError bool `short:"ew" name:"exitWhenError" brief:"exit building when any error occurs, specially for multiple arch and system buildings. default is false" orphan:"true"`
DumpENV bool `short:"de" name:"dumpEnv" brief:"dump current go build environment before building binary" orphan:"true"`
}
type cBuildOutput struct{}
func (c cBuild) Index(ctx context.Context, in cBuildInput) (out *cBuildOutput, err error) {
// print used go env
if in.DumpENV {
_, _ = Env.Index(ctx, cEnvInput{})
}
mlog.SetHeaderPrint(true)
mlog.Debugf(`build input: %+v`, in)
mlog.Debugf(`build command input: %+v`, in)
// Necessary check.
if gproc.SearchBinary("go") == "" {
mlog.Fatalf(`command "go" not found in your environment, please install golang first to proceed this command`)
@@ -236,7 +248,7 @@ func (c cBuild) Index(ctx context.Context, in cBuildInput) (out *cBuildOutput, e
if len(customSystems) > 0 && customSystems[0] != "all" && !gstr.InArray(customSystems, system) {
continue
}
for arch, _ := range item {
for arch := range item {
if len(customArches) > 0 && customArches[0] != "all" && !gstr.InArray(customArches, arch) {
continue
}
@@ -299,8 +311,9 @@ func (c cBuild) getBuildInVarStr(ctx context.Context, in cBuildInput) string {
if buildInVarMap == nil {
buildInVarMap = make(g.Map)
}
buildInVarMap["builtGit"] = c.getGitCommit(ctx)
buildInVarMap["builtTime"] = gtime.Now().String()
buildInVarMap[`builtGit`] = c.getGitCommit(ctx)
buildInVarMap[`builtTime`] = gtime.Now().String()
buildInVarMap[`builtVersion`] = in.Version
b, err := json.Marshal(buildInVarMap)
if err != nil {
mlog.Fatal(err)

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
@@ -33,6 +39,7 @@ gf docker main.go
gf docker main.go -t hub.docker.com/john/image:tag
gf docker main.go -t hub.docker.com/john/image:tag
gf docker main.go -p -t hub.docker.com/john/image:tag
gf docker main.go -p -tp ["hub.docker.com/john","hub.docker.com/smith"] -tn image:tag
`
cDockerDc = `
The "docker" command builds the GF project to a docker images.
@@ -45,6 +52,7 @@ You should have docker installed, and there must be a Dockerfile in the root of
cDockerFileBrief = `file path of the Dockerfile. it's "manifest/docker/Dockerfile" in default`
cDockerShellBrief = `path of the shell file which is executed before docker build`
cDockerPushBrief = `auto push the docker image to docker registry if "-t" option passed`
cDockerTagBrief = `full tag for this docker, pattern like "xxx.xxx.xxx/image:tag"`
cDockerTagNameBrief = `tag name for this docker, pattern like "image:tag". this option is required with TagPrefixes`
cDockerTagPrefixesBrief = `tag prefixes for this docker, which are used for docker push. this option is required with TagName`
cDockerExtraBrief = `extra build options passed to "docker image"`
@@ -61,6 +69,7 @@ func init() {
`cDockerShellBrief`: cDockerShellBrief,
`cDockerBuildBrief`: cDockerBuildBrief,
`cDockerPushBrief`: cDockerPushBrief,
`cDockerTagBrief`: cDockerTagBrief,
`cDockerTagNameBrief`: cDockerTagNameBrief,
`cDockerTagPrefixesBrief`: cDockerTagPrefixesBrief,
`cDockerExtraBrief`: cDockerExtraBrief,
@@ -69,10 +78,11 @@ func init() {
type cDockerInput struct {
g.Meta `name:"docker" config:"gfcli.docker"`
Main string `name:"MAIN" arg:"true" brief:"{cDockerMainBrief}" d:"main.go"`
Main string `name:"MAIN" arg:"true" brief:"{cDockerMainBrief}" d:"main.go"`
File string `name:"file" short:"f" brief:"{cDockerFileBrief}" d:"manifest/docker/Dockerfile"`
Shell string `name:"shell" short:"s" brief:"{cDockerShellBrief}" d:"manifest/docker/docker.sh"`
Build string `name:"build" short:"b" brief:"{cDockerBuildBrief}" d:"-a amd64 -s linux"`
Build string `name:"build" short:"b" brief:"{cDockerBuildBrief}"`
Tag string `name:"tag" short:"t" brief:"{cDockerTagBrief}"`
TagName string `name:"tagName" short:"tn" brief:"{cDockerTagNameBrief}" v:"required-with:TagPrefixes"`
TagPrefixes []string `name:"tagPrefixes" short:"tp" brief:"{cDockerTagPrefixesBrief}" v:"required-with:TagName"`
Push bool `name:"push" short:"p" brief:"{cDockerPushBrief}" orphan:"true"`
@@ -87,17 +97,23 @@ func (c cDocker) Index(ctx context.Context, in cDockerInput) (out *cDockerOutput
mlog.Fatalf(`command "docker" not found in your environment, please install docker first to proceed this command`)
}
mlog.Debugf(`docker command input: %+v`, in)
// Binary build.
in.Build += " --exit"
if in.Main != "" {
if err = gproc.ShellRun(ctx, fmt.Sprintf(`gf build %s %s`, in.Main, in.Build)); err != nil {
return
if in.Main != "" && in.Build != "" {
in.Build += " --exitWhenError"
if in.Main != "" {
if err = gproc.ShellRun(ctx, fmt.Sprintf(`gf build %s %s`, in.Main, in.Build)); err != nil {
mlog.Debugf(`build binary failed with error: %+v`, err)
return
}
}
}
// Shell executing.
if in.Shell != "" && gfile.Exists(in.Shell) {
if err = c.exeDockerShell(ctx, in.Shell); err != nil {
mlog.Debugf(`build docker failed with error: %+v`, err)
return
}
}
@@ -114,7 +130,7 @@ func (c cDocker) Index(ctx context.Context, in cDockerInput) (out *cDockerOutput
}
}
if len(dockerTags) == 0 {
dockerTags = []string{""}
dockerTags = []string{in.Tag}
}
for i, dockerTag := range dockerTags {
if i > 0 {

View File

@@ -1,14 +1,21 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"bytes"
"context"
"github.com/olekukonko/tablewriter"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/olekukonko/tablewriter"
"hotgo/internal/library/hggen/internal/utility/mlog"
)

View File

@@ -1,8 +1,17 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"context"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
@@ -19,7 +28,9 @@ type cFix struct {
}
type cFixInput struct {
g.Meta `name:"fix"`
g.Meta `name:"fix"`
Path string `name:"path" short:"p" brief:"directory path, it uses current working directory in default"`
Version string `name:"version" short:"v" brief:"custom specified version to fix, leave it empty to auto detect"`
}
type cFixOutput struct{}
@@ -30,35 +41,46 @@ type cFixItem struct {
}
func (c cFix) Index(ctx context.Context, in cFixInput) (out *cFixOutput, err error) {
mlog.Print(`start auto fixing...`)
defer mlog.Print(`done!`)
err = c.doFix()
if in.Path == "" {
in.Path = gfile.Pwd()
}
if in.Version == "" {
in.Version, err = c.autoDetectVersion(in)
if err != nil {
mlog.Fatal(err)
}
if in.Version == "" {
mlog.Print(`no GoFrame usage found, exit fixing`)
return
}
mlog.Debugf(`current GoFrame version auto detect "%s"`, in.Version)
}
if !gproc.IsChild() {
mlog.Printf(`start auto fixing directory path "%s"...`, in.Path)
defer mlog.Print(`done!`)
}
err = c.doFix(in)
return
}
func (c cFix) doFix() (err error) {
version, err := c.getVersion()
if err != nil {
mlog.Fatal(err)
}
if version == "" {
mlog.Print(`no GoFrame usage found, exit fixing`)
return
}
mlog.Debugf(`current GoFrame version found "%s"`, version)
func (c cFix) doFix(in cFixInput) (err error) {
var items = []cFixItem{
{Version: "v2.3", Func: c.doFixV23},
{Version: "v2.5", Func: c.doFixV25},
}
for _, item := range items {
if gstr.CompareVersionGo(version, item.Version) < 0 {
if gstr.CompareVersionGo(in.Version, item.Version) < 0 {
mlog.Debugf(
`current GoFrame version "%s" is lesser than "%s", nothing to do`,
version, item.Version,
`current GoFrame or contrib package version "%s" is lesser than "%s", nothing to do`,
in.Version, item.Version,
)
continue
}
if err = item.Func(version); err != nil {
if err = item.Func(in.Version); err != nil {
return
}
}
@@ -68,16 +90,47 @@ func (c cFix) doFix() (err error) {
// doFixV23 fixes code when upgrading to GoFrame v2.3.
func (c cFix) doFixV23(version string) error {
replaceFunc := func(path, content string) string {
// gdb.TX from struct to interface.
content = gstr.Replace(content, "*gdb.TX", "gdb.TX")
// function name changes for package gtcp/gudp.
if gstr.Contains(content, "/gf/v2/net/gtcp") || gstr.Contains(content, "/gf/v2/net/gudp") {
content = gstr.ReplaceByMap(content, g.MapStrStr{
".SetSendDeadline": ".SetDeadlineSend",
".SetReceiveDeadline": ".SetDeadlineRecv",
".SetReceiveBufferWait": ".SetBufferWaitRecv",
})
}
return content
}
return gfile.ReplaceDirFunc(replaceFunc, ".", "*.go", true)
}
func (c cFix) getVersion() (string, error) {
// doFixV25 fixes code when upgrading to GoFrame v2.5.
func (c cFix) doFixV25(version string) (err error) {
replaceFunc := func(path, content string) string {
content, err = c.doFixV25Content(content)
return content
}
return gfile.ReplaceDirFunc(replaceFunc, ".", "*.go", true)
}
func (c cFix) doFixV25Content(content string) (newContent string, err error) {
newContent = content
if gstr.Contains(content, `.BindHookHandlerByMap(`) {
var pattern = `\.BindHookHandlerByMap\((.+?), map\[string\]ghttp\.HandlerFunc`
newContent, err = gregex.ReplaceString(
pattern,
`.BindHookHandlerByMap($1, map[ghttp.HookName]ghttp.HandlerFunc`,
content,
)
}
return
}
func (c cFix) autoDetectVersion(in cFixInput) (string, error) {
var (
err error
path = "go.mod"
path = gfile.Join(in.Path, "go.mod")
version string
)
if !gfile.Exists(path) {
@@ -86,7 +139,7 @@ func (c cFix) getVersion() (string, error) {
err = gfile.ReadLines(path, func(line string) error {
array := gstr.SplitAndTrim(line, " ")
if len(array) > 0 {
if array[0] == gfPackage {
if gstr.HasPrefix(array[0], gfPackage) {
version = array[1]
}
}

View File

@@ -0,0 +1,20 @@
package cmd
import (
"fmt"
"testing"
"github.com/gogf/gf/v2/test/gtest"
)
func Test_Fix_doFixV25Content(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
var (
content = gtest.DataContent(`fix25_content.go.txt`)
f = cFix{}
)
newContent, err := f.doFixV25Content(content)
t.AssertNil(err)
fmt.Println(newContent)
})
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
@@ -13,6 +19,7 @@ type cGen struct {
g.Meta `name:"gen" brief:"{cGenBrief}" dc:"{cGenDc}"`
cGenDao
cGenEnums
cGenCtrl
cGenPb
cGenPbEntity
cGenService

View File

@@ -0,0 +1,15 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"hotgo/internal/library/hggen/internal/cmd/genctrl"
)
type (
cGenCtrl = genctrl.CGenCtrl
)

View File

@@ -1,8 +1,15 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
//_ "github.com/gogf/gf/contrib/drivers/clickhouse/v2"
//_ "github.com/gogf/gf/contrib/drivers/mssql/v2"
//_ "github.com/gogf/gf/contrib/drivers/mysql/v2"
_ "github.com/gogf/gf/contrib/drivers/mysql/v2"
//_ "github.com/gogf/gf/contrib/drivers/oracle/v2"
//_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
//_ "github.com/gogf/gf/contrib/drivers/sqlite/v2"

View File

@@ -1,79 +1,13 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"context"
"fmt"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/genv"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gproc"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
import "hotgo/internal/library/hggen/internal/cmd/genpb"
type (
cGenPb struct{}
cGenPbInput struct {
g.Meta `name:"pb" brief:"parse proto files and generate protobuf go files"`
}
cGenPbOutput struct{}
cGenPb = genpb.CGenPb
)
func (c cGenPb) Pb(ctx context.Context, in cGenPbInput) (out *cGenPbOutput, err error) {
// Necessary check.
if gproc.SearchBinary("protoc") == "" {
mlog.Fatalf(`command "protoc" not found in your environment, please install protoc first to proceed this command`)
}
// protocol fold checks.
protoFolder := "protocol"
if !gfile.Exists(protoFolder) {
mlog.Fatalf(`proto files folder "%s" does not exist`, protoFolder)
}
// folder scanning.
files, err := gfile.ScanDirFile(protoFolder, "*.proto", true)
if err != nil {
mlog.Fatal(err)
}
if len(files) == 0 {
mlog.Fatalf(`no proto files found in folder "%s"`, protoFolder)
}
dirSet := gset.NewStrSet()
for _, file := range files {
dirSet.Add(gfile.Dir(file))
}
var (
servicePath = gfile.RealPath(".")
goPathSrc = gfile.RealPath(gfile.Join(genv.Get("GOPATH").String(), "src"))
)
dirSet.Iterator(func(protoDirPath string) bool {
parsingCommand := fmt.Sprintf(
"protoc --gofast_out=plugins=grpc:. %s/*.proto -I%s",
protoDirPath,
servicePath,
)
if goPathSrc != "" {
parsingCommand += " -I" + goPathSrc
}
mlog.Print(parsingCommand)
if output, err := gproc.ShellExec(ctx, parsingCommand); err != nil {
mlog.Print(output)
mlog.Fatal(err)
}
return true
})
// Custom replacement.
//pbFolder := "protobuf"
//_, _ = gfile.ScanDirFileFunc(pbFolder, "*.go", true, func(path string) string {
// content := gfile.GetContents(path)
// content = gstr.ReplaceByArray(content, g.SliceStr{
// `gtime "gtime"`, `gtime "github.com/gogf/gf/v2/os/gtime"`,
// })
// _ = gfile.PutContents(path, content)
// utils.GoFmt(path)
// return path
//})
mlog.Print("done!")
return
}

View File

@@ -1,411 +1,13 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"bytes"
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gtag"
"github.com/olekukonko/tablewriter"
"hotgo/internal/library/hggen/internal/consts"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
import "hotgo/internal/library/hggen/internal/cmd/genpbentity"
type (
cGenPbEntity struct{}
cGenPbEntityInput struct {
g.Meta `name:"pbentity" config:"{cGenPbEntityConfig}" brief:"{cGenPbEntityBrief}" eg:"{cGenPbEntityEg}" ad:"{cGenPbEntityAd}"`
Path string `name:"path" short:"p" brief:"{cGenPbEntityBriefPath}"`
Package string `name:"package" short:"k" brief:"{cGenPbEntityBriefPackage}"`
Link string `name:"link" short:"l" brief:"{cGenPbEntityBriefLink}"`
Tables string `name:"tables" short:"t" brief:"{cGenPbEntityBriefTables}"`
Prefix string `name:"prefix" short:"f" brief:"{cGenPbEntityBriefPrefix}"`
RemovePrefix string `name:"removePrefix" short:"r" brief:"{cGenPbEntityBriefRemovePrefix}"`
NameCase string `name:"nameCase" short:"n" brief:"{cGenPbEntityBriefNameCase}" d:"Camel"`
JsonCase string `name:"jsonCase" short:"j" brief:"{cGenPbEntityBriefJsonCase}" d:"CamelLower"`
Option string `name:"option" short:"o" brief:"{cGenPbEntityBriefOption}"`
}
cGenPbEntityOutput struct{}
cGenPbEntityInternalInput struct {
cGenPbEntityInput
TableName string // TableName specifies the table name of the table.
NewTableName string // NewTableName specifies the prefix-stripped name of the table.
}
cGenPbEntity = genpbentity.CGenPbEntity
)
const (
cGenPbEntityConfig = `gfcli.gen.pbentity`
cGenPbEntityBrief = `generate entity message files in protobuf3 format`
cGenPbEntityEg = `
gf gen pbentity
gf gen pbentity -l "mysql:root:12345678@tcp(127.0.0.1:3306)/test"
gf gen pbentity -p ./protocol/demos/entity -t user,user_detail,user_login
gf gen pbentity -r user_
`
cGenPbEntityAd = `
CONFIGURATION SUPPORT
Options are also supported by configuration file.
It's suggested using configuration file instead of command line arguments making producing.
The configuration node name is "gf.gen.pbentity", which also supports multiple databases, for example(config.yaml):
gfcli:
gen:
- pbentity:
link: "mysql:root:12345678@tcp(127.0.0.1:3306)/test"
path: "protocol/demos/entity"
tables: "order,products"
package: "demos"
- pbentity:
link: "mysql:root:12345678@tcp(127.0.0.1:3306)/primary"
path: "protocol/demos/entity"
prefix: "primary_"
tables: "user, userDetail"
package: "demos"
option: |
option go_package = "protobuf/demos";
option java_package = "protobuf/demos";
option php_namespace = "protobuf/demos";
`
cGenPbEntityBriefPath = `directory path for generated files`
cGenPbEntityBriefPackage = `package name for all entity proto files`
cGenPbEntityBriefLink = `database configuration, the same as the ORM configuration of GoFrame`
cGenPbEntityBriefTables = `generate models only for given tables, multiple table names separated with ','`
cGenPbEntityBriefPrefix = `add specified prefix for all entity names and entity proto files`
cGenPbEntityBriefRemovePrefix = `remove specified prefix of the table, multiple prefix separated with ','`
cGenPbEntityBriefOption = `extra protobuf options`
cGenPbEntityBriefGroup = `
specifying the configuration group name of database for generated ORM instance,
it's not necessary and the default value is "default"
`
cGenPbEntityBriefNameCase = `
case for message attribute names, default is "Camel":
| Case | Example |
|---------------- |--------------------|
| Camel | AnyKindOfString |
| CamelLower | anyKindOfString | default
| Snake | any_kind_of_string |
| SnakeScreaming | ANY_KIND_OF_STRING |
| SnakeFirstUpper | rgb_code_md5 |
| Kebab | any-kind-of-string |
| KebabScreaming | ANY-KIND-OF-STRING |
`
cGenPbEntityBriefJsonCase = `
case for message json tag, cases are the same as "nameCase", default "CamelLower".
set it to "none" to ignore json tag generating.
`
)
func init() {
gtag.Sets(g.MapStrStr{
`cGenPbEntityConfig`: cGenPbEntityConfig,
`cGenPbEntityBrief`: cGenPbEntityBrief,
`cGenPbEntityEg`: cGenPbEntityEg,
`cGenPbEntityAd`: cGenPbEntityAd,
`cGenPbEntityBriefPath`: cGenPbEntityBriefPath,
`cGenPbEntityBriefPackage`: cGenPbEntityBriefPackage,
`cGenPbEntityBriefLink`: cGenPbEntityBriefLink,
`cGenPbEntityBriefTables`: cGenPbEntityBriefTables,
`cGenPbEntityBriefPrefix`: cGenPbEntityBriefPrefix,
`cGenPbEntityBriefRemovePrefix`: cGenPbEntityBriefRemovePrefix,
`cGenPbEntityBriefGroup`: cGenPbEntityBriefGroup,
`cGenPbEntityBriefNameCase`: cGenPbEntityBriefNameCase,
`cGenPbEntityBriefJsonCase`: cGenPbEntityBriefJsonCase,
`cGenPbEntityBriefOption`: cGenPbEntityBriefOption,
})
}
func (c cGenPbEntity) PbEntity(ctx context.Context, in cGenPbEntityInput) (out *cGenPbEntityOutput, err error) {
var (
config = g.Cfg()
)
if config.Available(ctx) {
v := config.MustGet(ctx, cGenPbEntityConfig)
if v.IsSlice() {
for i := 0; i < len(v.Interfaces()); i++ {
doGenPbEntityForArray(ctx, i, in)
}
} else {
doGenPbEntityForArray(ctx, -1, in)
}
} else {
doGenPbEntityForArray(ctx, -1, in)
}
mlog.Print("done!")
return
}
func doGenPbEntityForArray(ctx context.Context, index int, in cGenPbEntityInput) {
var (
err error
db gdb.DB
)
if index >= 0 {
err = g.Cfg().MustGet(
ctx,
fmt.Sprintf(`%s.%d`, cGenPbEntityConfig, index),
).Scan(&in)
if err != nil {
mlog.Fatalf(`invalid configuration of "%s": %+v`, cGenPbEntityConfig, err)
}
}
if in.Package == "" {
mlog.Fatal("package name should not be empty")
}
removePrefixArray := gstr.SplitAndTrim(in.RemovePrefix, ",")
// It uses user passed database configuration.
if in.Link != "" {
var (
tempGroup = gtime.TimestampNanoStr()
match, _ = gregex.MatchString(`([a-z]+):(.+)`, in.Link)
)
if len(match) == 3 {
gdb.AddConfigNode(tempGroup, gdb.ConfigNode{
Type: gstr.Trim(match[1]),
Link: gstr.Trim(match[2]),
})
db, _ = gdb.Instance(tempGroup)
}
} else {
db = g.DB()
}
if db == nil {
mlog.Fatal("database initialization failed")
}
tableNames := ([]string)(nil)
if in.Tables != "" {
tableNames = gstr.SplitAndTrim(in.Tables, ",")
} else {
tableNames, err = db.Tables(context.TODO())
if err != nil {
mlog.Fatalf("fetching tables failed: \n %v", err)
}
}
for _, tableName := range tableNames {
newTableName := tableName
for _, v := range removePrefixArray {
newTableName = gstr.TrimLeftStr(newTableName, v, 1)
}
generatePbEntityContentFile(ctx, db, cGenPbEntityInternalInput{
cGenPbEntityInput: in,
TableName: tableName,
NewTableName: newTableName,
})
}
}
// generatePbEntityContentFile generates the protobuf files for given table.
func generatePbEntityContentFile(ctx context.Context, db gdb.DB, in cGenPbEntityInternalInput) {
fieldMap, err := db.TableFields(ctx, in.TableName)
if err != nil {
mlog.Fatalf("fetching tables fields failed for table '%s':\n%v", in.TableName, err)
}
// Change the `newTableName` if `Prefix` is given.
newTableName := "Entity_" + in.Prefix + in.NewTableName
var (
tableNameCamelCase = gstr.CaseCamel(newTableName)
tableNameSnakeCase = gstr.CaseSnake(newTableName)
entityMessageDefine = generateEntityMessageDefinition(tableNameCamelCase, fieldMap, in)
fileName = gstr.Trim(tableNameSnakeCase, "-_.")
path = gfile.Join(in.Path, fileName+".proto")
)
entityContent := gstr.ReplaceByMap(getTplPbEntityContent(""), g.MapStrStr{
"{PackageName}": in.Package,
"{OptionContent}": in.Option,
"{EntityMessage}": entityMessageDefine,
})
if err := gfile.PutContents(path, strings.TrimSpace(entityContent)); err != nil {
mlog.Fatalf("writing content to '%s' failed: %v", path, err)
} else {
mlog.Print("generated:", path)
}
}
// generateEntityMessageDefinition generates and returns the message definition for specified table.
func generateEntityMessageDefinition(entityName string, fieldMap map[string]*gdb.TableField, in cGenPbEntityInternalInput) string {
var (
buffer = bytes.NewBuffer(nil)
array = make([][]string, len(fieldMap))
names = sortFieldKeyForPbEntity(fieldMap)
)
for index, name := range names {
array[index] = generateMessageFieldForPbEntity(index+1, fieldMap[name], in)
}
tw := tablewriter.NewWriter(buffer)
tw.SetBorder(false)
tw.SetRowLine(false)
tw.SetAutoWrapText(false)
tw.SetColumnSeparator("")
tw.AppendBulk(array)
tw.Render()
stContent := buffer.String()
// Let's do this hack of table writer for indent!
stContent = gstr.Replace(stContent, " #", "")
buffer.Reset()
buffer.WriteString(fmt.Sprintf("message %s {\n", entityName))
buffer.WriteString(stContent)
buffer.WriteString("}")
return buffer.String()
}
// generateMessageFieldForPbEntity generates and returns the message definition for specified field.
func generateMessageFieldForPbEntity(index int, field *gdb.TableField, in cGenPbEntityInternalInput) []string {
var (
typeName string
comment string
jsonTagStr string
)
t, _ := gregex.ReplaceString(`\(.+\)`, "", field.Type)
t = gstr.Split(gstr.Trim(t), " ")[0]
t = gstr.ToLower(t)
switch t {
case "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob":
typeName = "bytes"
case "bit", "int", "tinyint", "small_int", "smallint", "medium_int", "mediumint", "serial":
if gstr.ContainsI(field.Type, "unsigned") {
typeName = "uint32"
} else {
typeName = "int32"
}
case "int8", "big_int", "bigint", "bigserial":
if gstr.ContainsI(field.Type, "unsigned") {
typeName = "uint64"
} else {
typeName = "int64"
}
case "real":
typeName = "float"
case "float", "double", "decimal", "smallmoney":
typeName = "double"
case "bool":
typeName = "bool"
case "datetime", "timestamp", "date", "time":
typeName = "int64"
default:
// Auto detecting type.
switch {
case strings.Contains(t, "int"):
typeName = "int"
case strings.Contains(t, "text") || strings.Contains(t, "char"):
typeName = "string"
case strings.Contains(t, "float") || strings.Contains(t, "double"):
typeName = "double"
case strings.Contains(t, "bool"):
typeName = "bool"
case strings.Contains(t, "binary") || strings.Contains(t, "blob"):
typeName = "bytes"
case strings.Contains(t, "date") || strings.Contains(t, "time"):
typeName = "int64"
default:
typeName = "string"
}
}
comment = gstr.ReplaceByArray(field.Comment, g.SliceStr{
"\n", " ",
"\r", " ",
})
comment = gstr.Trim(comment)
comment = gstr.Replace(comment, `\n`, " ")
comment, _ = gregex.ReplaceString(`\s{2,}`, ` `, comment)
if jsonTagName := formatCase(field.Name, in.JsonCase); jsonTagName != "" {
jsonTagStr = fmt.Sprintf(`[(gogoproto.jsontag) = "%s"]`, jsonTagName)
// beautiful indent.
if index < 10 {
// 3 spaces
jsonTagStr = " " + jsonTagStr
} else if index < 100 {
// 2 spaces
jsonTagStr = " " + jsonTagStr
} else {
// 1 spaces
jsonTagStr = " " + jsonTagStr
}
}
return []string{
" #" + typeName,
" #" + formatCase(field.Name, in.NameCase),
" #= " + gconv.String(index) + jsonTagStr + ";",
" #" + fmt.Sprintf(`// %s`, comment),
}
}
func getTplPbEntityContent(tplEntityPath string) string {
if tplEntityPath != "" {
return gfile.GetContents(tplEntityPath)
}
return consts.TemplatePbEntityMessageContent
}
// formatCase call gstr.Case* function to convert the s to specified case.
func formatCase(str, caseStr string) string {
switch gstr.ToLower(caseStr) {
case gstr.ToLower("Camel"):
return gstr.CaseCamel(str)
case gstr.ToLower("CamelLower"):
return gstr.CaseCamelLower(str)
case gstr.ToLower("Kebab"):
return gstr.CaseKebab(str)
case gstr.ToLower("KebabScreaming"):
return gstr.CaseKebabScreaming(str)
case gstr.ToLower("Snake"):
return gstr.CaseSnake(str)
case gstr.ToLower("SnakeFirstUpper"):
return gstr.CaseSnakeFirstUpper(str)
case gstr.ToLower("SnakeScreaming"):
return gstr.CaseSnakeScreaming(str)
case "none":
return ""
}
return str
}
func sortFieldKeyForPbEntity(fieldMap map[string]*gdb.TableField) []string {
names := make(map[int]string)
for _, field := range fieldMap {
names[field.Index] = field.Name
}
var (
result = make([]string, len(names))
i = 0
j = 0
)
for {
if len(names) == 0 {
break
}
if val, ok := names[i]; ok {
result[j] = val
j++
delete(names, i)
}
i++
}
return result
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (

View File

@@ -1,9 +1,14 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"context"
"fmt"
"github.com/gogf/gf/v2/text/gstr"
"os"
"strings"
@@ -12,6 +17,7 @@ import (
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/os/gres"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gtag"
"hotgo/internal/library/hggen/internal/utility/allyes"

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (

View File

@@ -1,9 +1,16 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"context"
"fmt"
"runtime"
"strings"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/frame/g"
@@ -154,7 +161,7 @@ func (app *cRunApp) Run(ctx context.Context) {
if runtime.GOOS == "windows" {
// Special handling for windows platform.
// DO NOT USE "cmd /c" command.
process = gproc.NewProcess(runCommand, nil)
process = gproc.NewProcess(outputPath, strings.Fields(app.Args))
} else {
process = gproc.NewProcessCmd(runCommand, nil)
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (

View File

@@ -1,15 +1,26 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
"context"
"fmt"
"runtime"
"github.com/minio/selfupdate"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gtag"
"hotgo/internal/library/hggen/internal/utility/mlog"
"hotgo/internal/library/hggen/internal/utility/utils"
)
var (
@@ -21,12 +32,12 @@ type cUp struct {
}
const (
gfPackage = `github.com/gogf/gf/v2`
gfPackage = `github.com/gogf/gf/`
cUpEg = `
gf up
gf up -a
gf up -c
gf up -f -c
gf up -cf
`
)
@@ -39,51 +50,90 @@ func init() {
type cUpInput struct {
g.Meta `name:"up" config:"gfcli.up"`
All bool `name:"all" short:"a" brief:"upgrade both version and cli, auto fix codes" orphan:"true"`
Fix bool `name:"fix" short:"f" brief:"auto fix codes" orphan:"true"`
Cli bool `name:"cli" short:"c" brief:"also upgrade CLI tool (not supported yet)" orphan:"true"`
Cli bool `name:"cli" short:"c" brief:"also upgrade CLI tool" orphan:"true"`
Fix bool `name:"fix" short:"f" brief:"auto fix codes(it only make sense if cli is to be upgraded)" orphan:"true"`
}
type cUpOutput struct{}
func (c cUp) Index(ctx context.Context, in cUpInput) (out *cUpOutput, err error) {
defer mlog.Print(`done!`)
defer func() {
if err == nil {
mlog.Print()
mlog.Print(`👏congratulations! you've upgraded to the latest version of GoFrame! enjoy it!👏`)
mlog.Print()
}
}()
var doUpgradeVersionOut *doUpgradeVersionOutput
if in.All {
in.Cli = true
in.Fix = true
}
if err = c.doUpgradeVersion(ctx); err != nil {
if doUpgradeVersionOut, err = c.doUpgradeVersion(ctx, in); err != nil {
return nil, err
}
if in.Fix {
if err = c.doAutoFixing(ctx); err != nil {
if in.Cli {
if err = c.doUpgradeCLI(ctx); err != nil {
return nil, err
}
}
//if in.Cli {
// if err = c.doUpgradeCLI(ctx); err != nil {
// return nil, err
// }
//}
if in.Cli && in.Fix {
if doUpgradeVersionOut != nil && len(doUpgradeVersionOut.Items) > 0 {
upgradedPathSet := gset.NewStrSet()
for _, item := range doUpgradeVersionOut.Items {
if !upgradedPathSet.AddIfNotExist(item.DirPath) {
continue
}
if err = c.doAutoFixing(ctx, item.DirPath, item.Version); err != nil {
return nil, err
}
}
}
}
return
}
func (c cUp) doUpgradeVersion(ctx context.Context) (err error) {
type doUpgradeVersionOutput struct {
Items []doUpgradeVersionOutputItem
}
type doUpgradeVersionOutputItem struct {
DirPath string
Version string
}
func (c cUp) doUpgradeVersion(ctx context.Context, in cUpInput) (out *doUpgradeVersionOutput, err error) {
mlog.Print(`start upgrading version...`)
out = &doUpgradeVersionOutput{
Items: make([]doUpgradeVersionOutputItem, 0),
}
type Package struct {
Name string
Version string
}
var (
dir = gfile.Pwd()
temp string
path = gfile.Join(dir, "go.mod")
temp string
dirPath = gfile.Pwd()
goModPath = gfile.Join(dirPath, "go.mod")
)
// It recursively upgrades the go.mod from sub folder to its parent folders.
for {
if gfile.Exists(path) {
var packages []string
err = gfile.ReadLines(path, func(line string) error {
if gfile.Exists(goModPath) {
var packages []Package
err = gfile.ReadLines(goModPath, func(line string) error {
line = gstr.Trim(line)
line = gstr.TrimLeftStr(line, "require ")
line = gstr.Trim(line)
if gstr.HasPrefix(line, gfPackage) {
pkg := gstr.Explode(" ", line)[0]
packages = append(packages, pkg)
array := gstr.SplitAndTrim(line, " ")
packages = append(packages, Package{
Name: array[0],
Version: array[1],
})
}
return nil
})
@@ -91,31 +141,79 @@ func (c cUp) doUpgradeVersion(ctx context.Context) (err error) {
return
}
for _, pkg := range packages {
mlog.Printf(`upgrading %s`, pkg)
command := fmt.Sprintf(`go get -u %s@latest`, pkg)
mlog.Printf(`upgrading "%s" from "%s" to "latest"`, pkg.Name, pkg.Version)
// go get -u
command := fmt.Sprintf(`cd %s && go get -u %s@latest`, dirPath, pkg.Name)
if err = gproc.ShellRun(ctx, command); err != nil {
return
}
// go mod tidy
if err = utils.GoModTidy(ctx, dirPath); err != nil {
return nil, err
}
out.Items = append(out.Items, doUpgradeVersionOutputItem{
DirPath: dirPath,
Version: pkg.Version,
})
}
return
}
temp = gfile.Dir(dir)
if temp == "" || temp == dir {
temp = gfile.Dir(dirPath)
if temp == "" || temp == dirPath {
return
}
dir = temp
path = gfile.Join(dir, "go.mod")
dirPath = temp
goModPath = gfile.Join(dirPath, "go.mod")
}
}
// doUpgradeCLI downloads the new version binary with process.
func (c cUp) doUpgradeCLI(ctx context.Context) (err error) {
mlog.Print(`start upgrading cli...`)
var (
downloadUrl = fmt.Sprintf(
`https://github.com/gogf/gf/releases/latest/download/gf_%s_%s`,
runtime.GOOS, runtime.GOARCH,
)
localSaveFilePath = gfile.SelfPath() + "~"
)
if runtime.GOOS == "windows" {
downloadUrl += ".exe"
}
mlog.Printf(`start downloading "%s" to "%s", it may take some time`, downloadUrl, localSaveFilePath)
err = utils.HTTPDownloadFileWithPercent(downloadUrl, localSaveFilePath)
if err != nil {
return err
}
defer func() {
mlog.Printf(`new version cli binary is successfully installed to "%s"`, gfile.SelfPath())
mlog.Printf(`remove temporary buffer file "%s"`, localSaveFilePath)
_ = gfile.Remove(localSaveFilePath)
}()
// It fails if file not exist or its size is less than 1MB.
if !gfile.Exists(localSaveFilePath) || gfile.Size(localSaveFilePath) < 1024*1024 {
mlog.Fatalf(`download "%s" to "%s" failed`, downloadUrl, localSaveFilePath)
}
newFile, err := gfile.Open(localSaveFilePath)
if err != nil {
return err
}
// selfupdate
err = selfupdate.Apply(newFile, selfupdate.Options{})
if err != nil {
return err
}
return
}
func (c cUp) doAutoFixing(ctx context.Context) (err error) {
mlog.Print(`start auto fixing...`)
err = cFix{}.doFix()
func (c cUp) doAutoFixing(ctx context.Context, dirPath string, version string) (err error) {
mlog.Printf(`auto fixing directory path "%s" from version "%s" ...`, dirPath, version)
command := fmt.Sprintf(`gf fix -p %s`, dirPath)
_ = gproc.ShellRun(ctx, command)
return
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package cmd
import (
@@ -64,6 +70,8 @@ func (c cVersion) getGFVersionOfCurrentProject() (string, error) {
if gfile.Exists(goModPath) {
lines := gstr.SplitAndTrim(gfile.GetContents(goModPath), "\n")
for _, line := range lines {
line = gstr.Trim(line)
line = gstr.TrimLeftStr(line, "require ")
line = gstr.Trim(line)
// Version 1.
match, err := gregex.MatchString(`^github\.com/gogf/gf\s+(.+)$`, line)

View File

@@ -0,0 +1,233 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genctrl
import (
"context"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gtag"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
const (
CGenCtrlConfig = `gfcli.gen.ctrl`
CGenCtrlUsage = `gf gen ctrl [OPTION]`
CGenCtrlBrief = `parse api definitions to generate controller/sdk go files`
CGenCtrlEg = `
gf gen ctrl
`
CGenCtrlBriefSrcFolder = `source folder path to be parsed. default: api`
CGenCtrlBriefDstFolder = `destination folder path storing automatically generated go files. default: internal/controller`
CGenCtrlBriefWatchFile = `used in file watcher, it re-generates go files only if given file is under srcFolder`
CGenCtrlBriefSdkPath = `also generate SDK go files for api definitions to specified directory`
CGenCtrlBriefSdkStdVersion = `use standard version prefix for generated sdk request path`
CGenCtrlBriefSdkNoV1 = `do not add version suffix for interface module name if version is v1`
CGenCtrlBriefClear = `auto delete generated and unimplemented controller go files if api definitions are missing`
)
const (
PatternApiDefinition = `type\s+(\w+)Req\s+struct\s+{`
PatternCtrlDefinition = `func\s+\(.+?\)\s+\w+\(.+?\*(\w+)\.(\w+)Req\)\s+\(.+?\*(\w+)\.(\w+)Res,\s+\w+\s+error\)\s+{`
)
const (
genCtrlFileLockSeconds = 10
)
func init() {
gtag.Sets(g.MapStrStr{
`CGenCtrlConfig`: CGenCtrlConfig,
`CGenCtrlUsage`: CGenCtrlUsage,
`CGenCtrlBrief`: CGenCtrlBrief,
`CGenCtrlEg`: CGenCtrlEg,
`CGenCtrlBriefSrcFolder`: CGenCtrlBriefSrcFolder,
`CGenCtrlBriefDstFolder`: CGenCtrlBriefDstFolder,
`CGenCtrlBriefWatchFile`: CGenCtrlBriefWatchFile,
`CGenCtrlBriefSdkPath`: CGenCtrlBriefSdkPath,
`CGenCtrlBriefSdkStdVersion`: CGenCtrlBriefSdkStdVersion,
`CGenCtrlBriefSdkNoV1`: CGenCtrlBriefSdkNoV1,
`CGenCtrlBriefClear`: CGenCtrlBriefClear,
})
}
type (
CGenCtrl struct{}
CGenCtrlInput struct {
g.Meta `name:"ctrl" config:"{CGenCtrlConfig}" usage:"{CGenCtrlUsage}" brief:"{CGenCtrlBrief}" eg:"{CGenCtrlEg}"`
SrcFolder string `short:"s" name:"srcFolder" brief:"{CGenCtrlBriefSrcFolder}" d:"api"`
DstFolder string `short:"d" name:"dstFolder" brief:"{CGenCtrlBriefDstFolder}" d:"internal/controller"`
WatchFile string `short:"w" name:"watchFile" brief:"{CGenCtrlBriefWatchFile}"`
SdkPath string `short:"k" name:"sdkPath" brief:"{CGenCtrlBriefSdkPath}"`
SdkStdVersion bool `short:"v" name:"sdkStdVersion" brief:"{CGenCtrlBriefSdkStdVersion}" orphan:"true"`
SdkNoV1 bool `short:"n" name:"sdkNoV1" brief:"{CGenCtrlBriefSdkNoV1}" orphan:"true"`
Clear bool `short:"c" name:"clear" brief:"{CGenCtrlBriefClear}" orphan:"true"`
}
CGenCtrlOutput struct{}
)
func (c CGenCtrl) Ctrl(ctx context.Context, in CGenCtrlInput) (out *CGenCtrlOutput, err error) {
if in.WatchFile != "" {
err = c.generateByWatchFile(
in.WatchFile, in.SdkPath, in.SdkStdVersion, in.SdkNoV1, in.Clear,
)
mlog.Print(`done!`)
return
}
if !gfile.Exists(in.SrcFolder) {
mlog.Fatalf(`source folder path "%s" does not exist`, in.SrcFolder)
}
// retrieve all api modules.
apiModuleFolderPaths, err := gfile.ScanDir(in.SrcFolder, "*", false)
if err != nil {
return nil, err
}
for _, apiModuleFolderPath := range apiModuleFolderPaths {
if !gfile.IsDir(apiModuleFolderPath) {
continue
}
// generate go files by api module.
var (
module = gfile.Basename(apiModuleFolderPath)
dstModuleFolderPath = gfile.Join(in.DstFolder, module)
)
err = c.generateByModule(
apiModuleFolderPath, dstModuleFolderPath, in.SdkPath,
in.SdkStdVersion, in.SdkNoV1, in.Clear,
)
if err != nil {
return nil, err
}
}
mlog.Print(`done!`)
return
}
func (c CGenCtrl) generateByWatchFile(watchFile, sdkPath string, sdkStdVersion, sdkNoV1, clear bool) (err error) {
// File lock to avoid multiple processes.
var (
flockFilePath = gfile.Temp("gf.cli.gen.service.lock")
flockContent = gfile.GetContents(flockFilePath)
)
if flockContent != "" {
if gtime.Timestamp()-gconv.Int64(flockContent) < genCtrlFileLockSeconds {
// If another generating process is running, it just exits.
mlog.Debug(`another "gen service" process is running, exit`)
return
}
}
defer gfile.Remove(flockFilePath)
_ = gfile.PutContents(flockFilePath, gtime.TimestampStr())
// check this updated file is an api file.
// watch file should be in standard goframe project structure.
var (
apiVersionPath = gfile.Dir(watchFile)
apiModuleFolderPath = gfile.Dir(apiVersionPath)
shouldBeNameOfAPi = gfile.Basename(gfile.Dir(apiModuleFolderPath))
)
if shouldBeNameOfAPi != "api" {
return nil
}
// watch file should have api definitions.
if gfile.Exists(watchFile) {
if !gregex.IsMatchString(PatternApiDefinition, gfile.GetContents(watchFile)) {
return nil
}
}
var (
projectRootPath = gfile.Dir(gfile.Dir(apiModuleFolderPath))
module = gfile.Basename(apiModuleFolderPath)
dstModuleFolderPath = gfile.Join(projectRootPath, "internal", "controller", module)
)
return c.generateByModule(
apiModuleFolderPath, dstModuleFolderPath, sdkPath, sdkStdVersion, sdkNoV1, clear,
)
}
// parseApiModule parses certain api and generate associated go files by certain module, not all api modules.
func (c CGenCtrl) generateByModule(
apiModuleFolderPath, dstModuleFolderPath, sdkPath string,
sdkStdVersion, sdkNoV1, clear bool,
) (err error) {
// parse src and dst folder go files.
apiItemsInSrc, err := c.getApiItemsInSrc(apiModuleFolderPath)
if err != nil {
return err
}
apiItemsInDst, err := c.getApiItemsInDst(dstModuleFolderPath)
if err != nil {
return err
}
// generate api interface go files.
if err = newApiInterfaceGenerator().Generate(apiModuleFolderPath, apiItemsInSrc); err != nil {
return
}
// generate controller go files.
// api filtering for already implemented api controllers.
var (
alreadyImplementedCtrlSet = gset.NewStrSet()
toBeImplementedApiItems = make([]apiItem, 0)
)
for _, item := range apiItemsInDst {
alreadyImplementedCtrlSet.Add(item.String())
}
for _, item := range apiItemsInSrc {
if alreadyImplementedCtrlSet.Contains(item.String()) {
continue
}
toBeImplementedApiItems = append(toBeImplementedApiItems, item)
}
if len(toBeImplementedApiItems) > 0 {
err = newControllerGenerator().Generate(dstModuleFolderPath, toBeImplementedApiItems)
if err != nil {
return
}
}
// delete unimplemented controllers if api definitions are missing.
if clear {
var (
apiDefinitionSet = gset.NewStrSet()
extraApiItemsInCtrl = make([]apiItem, 0)
)
for _, item := range apiItemsInSrc {
apiDefinitionSet.Add(item.String())
}
for _, item := range apiItemsInDst {
if apiDefinitionSet.Contains(item.String()) {
continue
}
extraApiItemsInCtrl = append(extraApiItemsInCtrl, item)
}
if len(extraApiItemsInCtrl) > 0 {
err = newControllerClearer().Clear(dstModuleFolderPath, extraApiItemsInCtrl)
if err != nil {
return
}
}
}
// generate sdk go files.
if sdkPath != "" {
if err = newApiSdkGenerator().Generate(apiItemsInSrc, sdkPath, sdkStdVersion, sdkNoV1); err != nil {
return
}
}
return
}

View File

@@ -0,0 +1,22 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genctrl
import "github.com/gogf/gf/v2/text/gstr"
type apiItem struct {
Import string `eg:"demo.com/api/user/v1"`
Module string `eg:"user"`
Version string `eg:"v1"`
MethodName string `eg:"GetList"`
}
func (a apiItem) String() string {
return gstr.Join([]string{
a.Import, a.Module, a.Version, a.MethodName,
}, ",")
}

View File

@@ -0,0 +1,136 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genctrl
import (
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"hotgo/internal/library/hggen/internal/utility/utils"
)
func (c CGenCtrl) getApiItemsInSrc(apiModuleFolderPath string) (items []apiItem, err error) {
var (
fileContent string
importPath string
)
// The second level folders: versions.
apiVersionFolderPaths, err := gfile.ScanDir(apiModuleFolderPath, "*", false)
if err != nil {
return nil, err
}
for _, apiVersionFolderPath := range apiVersionFolderPaths {
if !gfile.IsDir(apiVersionFolderPath) {
continue
}
// The second level folders: versions.
apiFileFolderPaths, err := gfile.ScanDir(apiVersionFolderPath, "*.go", false)
if err != nil {
return nil, err
}
importPath = utils.GetImportPath(apiVersionFolderPath)
for _, apiFileFolderPath := range apiFileFolderPaths {
if gfile.IsDir(apiFileFolderPath) {
continue
}
fileContent = gfile.GetContents(apiFileFolderPath)
matches, err := gregex.MatchAllString(PatternApiDefinition, fileContent)
if err != nil {
return nil, err
}
for _, match := range matches {
item := apiItem{
Import: gstr.Trim(importPath, `"`),
Module: gfile.Basename(apiModuleFolderPath),
Version: gfile.Basename(apiVersionFolderPath),
MethodName: match[1],
}
items = append(items, item)
}
}
}
return
}
func (c CGenCtrl) getApiItemsInDst(dstFolder string) (items []apiItem, err error) {
if !gfile.Exists(dstFolder) {
return nil, nil
}
type importItem struct {
Path string
Alias string
}
var fileContent string
filePaths, err := gfile.ScanDir(dstFolder, "*.go", true)
if err != nil {
return nil, err
}
for _, filePath := range filePaths {
fileContent = gfile.GetContents(filePath)
match, err := gregex.MatchString(`import\s+\(([\s\S]+?)\)`, fileContent)
if err != nil {
return nil, err
}
if len(match) < 2 {
continue
}
var (
array []string
importItems []importItem
importLines = gstr.SplitAndTrim(match[1], "\n")
module = gfile.Basename(gfile.Dir(filePath))
)
// retrieve all imports.
for _, importLine := range importLines {
array = gstr.SplitAndTrim(importLine, " ")
if len(array) == 2 {
importItems = append(importItems, importItem{
Path: gstr.Trim(array[1], `"`),
Alias: array[0],
})
} else {
importItems = append(importItems, importItem{
Path: gstr.Trim(array[0], `"`),
})
}
}
// retrieve all api usages.
matches, err := gregex.MatchAllString(PatternCtrlDefinition, fileContent)
if err != nil {
return nil, err
}
for _, match = range matches {
// try to find the import path of the api.
var (
importPath string
version = match[1]
methodName = match[2] // not the function name, but the method name in api definition.
)
for _, item := range importItems {
if item.Alias != "" {
if item.Alias == version {
importPath = item.Path
break
}
continue
}
if gfile.Basename(item.Path) == version {
importPath = item.Path
break
}
}
item := apiItem{
Import: gstr.Trim(importPath, `"`),
Module: module,
Version: gfile.Basename(importPath),
MethodName: methodName,
}
items = append(items, item)
}
}
return
}

View File

@@ -0,0 +1,137 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genctrl
import (
"fmt"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gstr"
"hotgo/internal/library/hggen/internal/consts"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
type controllerGenerator struct{}
func newControllerGenerator() *controllerGenerator {
return &controllerGenerator{}
}
func (c *controllerGenerator) Generate(dstModuleFolderPath string, apiModuleApiItems []apiItem) (err error) {
var (
doneApiItemSet = gset.NewStrSet()
)
for _, item := range apiModuleApiItems {
if doneApiItemSet.Contains(item.String()) {
continue
}
// retrieve all api items of the same module.
subItems := c.getSubItemsByModuleAndVersion(apiModuleApiItems, item.Module, item.Version)
if err = c.doGenerateCtrlNewByModuleAndVersion(
dstModuleFolderPath, item.Module, item.Version, gfile.Dir(item.Import),
); err != nil {
return
}
for _, subItem := range subItems {
if err = c.doGenerateCtrlItem(dstModuleFolderPath, subItem); err != nil {
return
}
doneApiItemSet.Add(subItem.String())
}
}
return
}
func (c *controllerGenerator) getSubItemsByModuleAndVersion(items []apiItem, module, version string) (subItems []apiItem) {
for _, item := range items {
if item.Module == module && item.Version == version {
subItems = append(subItems, item)
}
}
return
}
func (c *controllerGenerator) doGenerateCtrlNewByModuleAndVersion(
dstModuleFolderPath, module, version, importPath string,
) (err error) {
var (
moduleFilePath = gfile.Join(dstModuleFolderPath, module+".go")
moduleFilePathNew = gfile.Join(dstModuleFolderPath, module+"_new.go")
ctrlName = fmt.Sprintf(`Controller%s`, gstr.UcFirst(version))
interfaceName = fmt.Sprintf(`%s.I%s%s`, module, gstr.CaseCamel(module), gstr.UcFirst(version))
newFuncName = fmt.Sprintf(`New%s`, gstr.UcFirst(version))
newFuncNameDefinition = fmt.Sprintf(`func %s()`, newFuncName)
alreadyCreated bool
)
// replace "\" to "/", fix import error
importPath = gstr.Replace(importPath, "\\", "/", -1)
if !gfile.Exists(moduleFilePath) {
content := gstr.ReplaceByMap(consts.TemplateGenCtrlControllerEmpty, g.MapStrStr{
"{Module}": module,
})
if err = gfile.PutContents(moduleFilePath, gstr.TrimLeft(content)); err != nil {
return err
}
mlog.Printf(`generated: %s`, moduleFilePath)
}
if !gfile.Exists(moduleFilePathNew) {
content := gstr.ReplaceByMap(consts.TemplateGenCtrlControllerNewEmpty, g.MapStrStr{
"{Module}": module,
"{ImportPath}": fmt.Sprintf(`"%s"`, importPath),
})
if err = gfile.PutContents(moduleFilePathNew, gstr.TrimLeft(content)); err != nil {
return err
}
mlog.Printf(`generated: %s`, moduleFilePathNew)
}
filePaths, err := gfile.ScanDir(dstModuleFolderPath, "*.go", false)
if err != nil {
return err
}
for _, filePath := range filePaths {
if gstr.Contains(gfile.GetContents(filePath), newFuncNameDefinition) {
alreadyCreated = true
break
}
}
if !alreadyCreated {
content := gstr.ReplaceByMap(consts.TemplateGenCtrlControllerNewFunc, g.MapStrStr{
"{CtrlName}": ctrlName,
"{NewFuncName}": newFuncName,
"{InterfaceName}": interfaceName,
})
err = gfile.PutContentsAppend(moduleFilePathNew, gstr.TrimLeft(content))
if err != nil {
return err
}
}
return
}
func (c *controllerGenerator) doGenerateCtrlItem(dstModuleFolderPath string, item apiItem) (err error) {
var (
methodNameSnake = gstr.CaseSnake(item.MethodName)
ctrlName = fmt.Sprintf(`Controller%s`, gstr.UcFirst(item.Version))
methodFilePath = gfile.Join(dstModuleFolderPath, fmt.Sprintf(
`%s_%s_%s.go`, item.Module, item.Version, methodNameSnake,
))
)
content := gstr.ReplaceByMap(consts.TemplateGenCtrlControllerMethodFunc, g.MapStrStr{
"{Module}": item.Module,
"{ImportPath}": item.Import,
"{CtrlName}": ctrlName,
"{Version}": item.Version,
"{MethodName}": item.MethodName,
})
if err = gfile.PutContents(methodFilePath, gstr.TrimLeft(content)); err != nil {
return err
}
mlog.Printf(`generated: %s`, methodFilePath)
return
}

View File

@@ -0,0 +1,57 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genctrl
import (
"fmt"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
type controllerClearer struct{}
func newControllerClearer() *controllerClearer {
return &controllerClearer{}
}
func (c *controllerClearer) Clear(dstModuleFolderPath string, extraApiItemsInCtrl []apiItem) (err error) {
for _, item := range extraApiItemsInCtrl {
if err = c.doClear(dstModuleFolderPath, item); err != nil {
return err
}
}
return
}
func (c *controllerClearer) doClear(dstModuleFolderPath string, item apiItem) (err error) {
var (
methodNameSnake = gstr.CaseSnake(item.MethodName)
methodFilePath = gfile.Join(dstModuleFolderPath, fmt.Sprintf(
`%s_%s_%s.go`, item.Module, item.Version, methodNameSnake,
))
fileContent = gstr.Trim(gfile.GetContents(methodFilePath))
)
match, err := gregex.MatchString(`.+?Req.+?Res.+?{([\s\S]+?)}`, fileContent)
if err != nil {
return err
}
if len(match) > 1 {
implements := gstr.Trim(match[1])
// One line.
if !gstr.Contains(implements, "\n") && gstr.Contains(implements, `CodeNotImplemented`) {
mlog.Printf(
`remove unimplemented and of no api definitions controller file: %s`,
methodFilePath,
)
err = gfile.Remove(methodFilePath)
}
}
return
}

View File

@@ -0,0 +1,111 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genctrl
import (
"fmt"
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/library/hggen/internal/consts"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
type apiInterfaceGenerator struct{}
func newApiInterfaceGenerator() *apiInterfaceGenerator {
return &apiInterfaceGenerator{}
}
func (c *apiInterfaceGenerator) Generate(apiModuleFolderPath string, apiModuleApiItems []apiItem) (err error) {
if len(apiModuleApiItems) == 0 {
return nil
}
var firstApiItem = apiModuleApiItems[0]
if err = c.doGenerate(apiModuleFolderPath, firstApiItem.Module, apiModuleApiItems); err != nil {
return
}
return
}
func (c *apiInterfaceGenerator) doGenerate(apiModuleFolderPath string, module string, items []apiItem) (err error) {
var (
moduleFilePath = gfile.Join(apiModuleFolderPath, fmt.Sprintf(`%s.go`, module))
importPathMap = gmap.NewListMap()
importPaths []string
)
// all import paths.
importPathMap.Set("\t"+`"context"`, 1)
importPathMap.Set("\t"+``, 1)
for _, item := range items {
importPathMap.Set(fmt.Sprintf("\t"+`"%s"`, item.Import), 1)
}
importPaths = gconv.Strings(importPathMap.Keys())
// interface definitions.
var (
doneApiItemSet = gset.NewStrSet()
interfaceDefinition string
interfaceContent = gstr.TrimLeft(gstr.ReplaceByMap(consts.TemplateGenCtrlApiInterface, g.MapStrStr{
"{Module}": module,
"{ImportPaths}": gstr.Join(importPaths, "\n"),
}))
)
for _, item := range items {
if doneApiItemSet.Contains(item.String()) {
continue
}
// retrieve all api items of the same module.
subItems := c.getSubItemsByModuleAndVersion(items, item.Module, item.Version)
var (
method string
methods = make([]string, 0)
interfaceName = fmt.Sprintf(`I%s%s`, gstr.CaseCamel(item.Module), gstr.UcFirst(item.Version))
)
for _, subItem := range subItems {
method = fmt.Sprintf(
"\t%s(ctx context.Context, req *%s.%sReq) (res *%s.%sRes, err error)",
subItem.MethodName, subItem.Version, subItem.MethodName, subItem.Version, subItem.MethodName,
)
methods = append(methods, method)
doneApiItemSet.Add(subItem.String())
}
interfaceDefinition += fmt.Sprintf("type %s interface {", interfaceName)
interfaceDefinition += "\n"
interfaceDefinition += gstr.Join(methods, "\n")
interfaceDefinition += "\n"
interfaceDefinition += fmt.Sprintf("}")
interfaceDefinition += "\n\n"
}
interfaceContent = gstr.TrimLeft(gstr.ReplaceByMap(interfaceContent, g.MapStrStr{
"{Interfaces}": interfaceDefinition,
}))
err = gfile.PutContents(moduleFilePath, interfaceContent)
mlog.Printf(`generated: %s`, moduleFilePath)
return
}
func (c *apiInterfaceGenerator) getSubItemsByModule(items []apiItem, module string) (subItems []apiItem) {
for _, item := range items {
if item.Module == module {
subItems = append(subItems, item)
}
}
return
}
func (c *apiInterfaceGenerator) getSubItemsByModuleAndVersion(items []apiItem, module, version string) (subItems []apiItem) {
for _, item := range items {
if item.Module == module && item.Version == version {
subItems = append(subItems, item)
}
}
return
}

View File

@@ -0,0 +1,194 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genctrl
import (
"fmt"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"hotgo/internal/library/hggen/internal/consts"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
type apiSdkGenerator struct{}
func newApiSdkGenerator() *apiSdkGenerator {
return &apiSdkGenerator{}
}
func (c *apiSdkGenerator) Generate(apiModuleApiItems []apiItem, sdkFolderPath string, sdkStdVersion, sdkNoV1 bool) (err error) {
if err = c.doGenerateSdkPkgFile(sdkFolderPath); err != nil {
return
}
var doneApiItemSet = gset.NewStrSet()
for _, item := range apiModuleApiItems {
if doneApiItemSet.Contains(item.String()) {
continue
}
// retrieve all api items of the same module.
subItems := c.getSubItemsByModuleAndVersion(apiModuleApiItems, item.Module, item.Version)
if err = c.doGenerateSdkIClient(sdkFolderPath, item.Import, item.Module, item.Version, sdkNoV1); err != nil {
return
}
if err = c.doGenerateSdkImplementer(
subItems, sdkFolderPath, item.Import, item.Module, item.Version, sdkStdVersion, sdkNoV1,
); err != nil {
return
}
for _, subItem := range subItems {
doneApiItemSet.Add(subItem.String())
}
}
return
}
func (c *apiSdkGenerator) doGenerateSdkPkgFile(sdkFolderPath string) (err error) {
var (
pkgName = gfile.Basename(sdkFolderPath)
pkgFilePath = gfile.Join(sdkFolderPath, fmt.Sprintf(`%s.go`, pkgName))
fileContent string
)
if gfile.Exists(pkgFilePath) {
return nil
}
fileContent = gstr.TrimLeft(gstr.ReplaceByMap(consts.TemplateGenCtrlSdkPkgNew, g.MapStrStr{
"{PkgName}": pkgName,
}))
err = gfile.PutContents(pkgFilePath, fileContent)
mlog.Printf(`generated: %s`, pkgFilePath)
return
}
func (c *apiSdkGenerator) doGenerateSdkIClient(
sdkFolderPath, versionImportPath, module, version string, sdkNoV1 bool,
) (err error) {
var (
fileContent string
isDirty bool
isExist bool
pkgName = gfile.Basename(sdkFolderPath)
funcName = gstr.CaseCamel(module) + gstr.UcFirst(version)
interfaceName = fmt.Sprintf(`I%s`, funcName)
moduleImportPath = fmt.Sprintf(`"%s"`, gfile.Dir(versionImportPath))
iClientFilePath = gfile.Join(sdkFolderPath, fmt.Sprintf(`%s.iclient.go`, pkgName))
interfaceFuncDefinition = fmt.Sprintf(
`%s() %s.%s`,
gstr.CaseCamel(module)+gstr.UcFirst(version), module, interfaceName,
)
)
if sdkNoV1 && version == "v1" {
interfaceFuncDefinition = fmt.Sprintf(
`%s() %s.%s`,
gstr.CaseCamel(module), module, interfaceName,
)
}
if isExist = gfile.Exists(iClientFilePath); isExist {
fileContent = gfile.GetContents(iClientFilePath)
} else {
fileContent = gstr.TrimLeft(gstr.ReplaceByMap(consts.TemplateGenCtrlSdkIClient, g.MapStrStr{
"{PkgName}": pkgName,
}))
}
// append the import path to current import paths.
if !gstr.Contains(fileContent, moduleImportPath) {
isDirty = true
fileContent, err = gregex.ReplaceString(
`(import \([\s\S]*?)\)`,
fmt.Sprintf("$1\t%s\n)", moduleImportPath),
fileContent,
)
if err != nil {
return
}
}
// append the function definition to interface definition.
if !gstr.Contains(fileContent, interfaceFuncDefinition) {
isDirty = true
fileContent, err = gregex.ReplaceString(
`(type iClient interface {[\s\S]*?)}`,
fmt.Sprintf("$1\t%s\n}", interfaceFuncDefinition),
fileContent,
)
if err != nil {
return
}
}
if isDirty {
err = gfile.PutContents(iClientFilePath, fileContent)
if isExist {
mlog.Printf(`updated: %s`, iClientFilePath)
} else {
mlog.Printf(`generated: %s`, iClientFilePath)
}
}
return
}
func (c *apiSdkGenerator) doGenerateSdkImplementer(
items []apiItem, sdkFolderPath, versionImportPath, module, version string, sdkStdVersion, sdkNoV1 bool,
) (err error) {
var (
pkgName = gfile.Basename(sdkFolderPath)
moduleNameCamel = gstr.CaseCamel(module)
moduleNameSnake = gstr.CaseSnake(module)
moduleImportPath = gfile.Dir(versionImportPath)
versionPrefix = ""
implementerName = moduleNameCamel + gstr.UcFirst(version)
implementerFilePath = gfile.Join(sdkFolderPath, fmt.Sprintf(
`%s_%s_%s.go`, pkgName, moduleNameSnake, version,
))
)
if sdkNoV1 && version == "v1" {
implementerName = moduleNameCamel
}
// implementer file template.
var importPaths = make([]string, 0)
importPaths = append(importPaths, fmt.Sprintf("\t\"%s\"", moduleImportPath))
importPaths = append(importPaths, fmt.Sprintf("\t\"%s\"", versionImportPath))
implementerFileContent := gstr.TrimLeft(gstr.ReplaceByMap(consts.TemplateGenCtrlSdkImplementer, g.MapStrStr{
"{PkgName}": pkgName,
"{ImportPaths}": gstr.Join(importPaths, "\n"),
"{ImplementerName}": implementerName,
}))
// implementer new function definition.
if sdkStdVersion {
versionPrefix = fmt.Sprintf(`/api/%s`, version)
}
implementerFileContent += gstr.TrimLeft(gstr.ReplaceByMap(consts.TemplateGenCtrlSdkImplementerNew, g.MapStrStr{
"{Module}": module,
"{VersionPrefix}": versionPrefix,
"{ImplementerName}": implementerName,
}))
// implementer functions definitions.
for _, item := range items {
implementerFileContent += gstr.TrimLeft(gstr.ReplaceByMap(consts.TemplateGenCtrlSdkImplementerFunc, g.MapStrStr{
"{Version}": item.Version,
"{MethodName}": item.MethodName,
"{ImplementerName}": implementerName,
}))
implementerFileContent += "\n"
}
err = gfile.PutContents(implementerFilePath, implementerFileContent)
mlog.Printf(`generated: %s`, implementerFilePath)
return
}
func (c *apiSdkGenerator) getSubItemsByModuleAndVersion(items []apiItem, module, version string) (subItems []apiItem) {
for _, item := range items {
if item.Module == module && item.Version == version {
subItems = append(subItems, item)
}
}
return
}

View File

@@ -1,18 +1,26 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gendao
import (
"context"
"fmt"
"golang.org/x/mod/modfile"
"strings"
"github.com/gogf/gf/v2/container/garray"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gtag"
"hotgo/internal/library/hggen/internal/utility/utils"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
@@ -43,6 +51,12 @@ CONFIGURATION SUPPORT
path: "./my-app"
prefix: "primary_"
tables: "user, userDetail"
typeMapping:
decimal:
type: decimal.Decimal
import: github.com/shopspring/decimal
numeric:
type: string
`
CGenDaoBriefPath = `directory path for generated files`
CGenDaoBriefLink = `database configuration, the same as the ORM configuration of GoFrame`
@@ -64,6 +78,7 @@ CONFIGURATION SUPPORT
CGenDaoBriefNoJsonTag = `no json tag will be added for each field`
CGenDaoBriefNoModelComment = `no model comment will be added for each field`
CGenDaoBriefClear = `delete all generated go files that do not exist in database`
CGenDaoBriefTypeMapping = `custom local type mapping for generated struct attributes relevant to fields of table`
CGenDaoBriefGroup = `
specifying the configuration group name of database for generated ORM instance,
it's not necessary and the default value is "default"
@@ -99,7 +114,21 @@ generated json tag case for model struct, cases are as follows:
)
var (
createdAt = gtime.Now()
createdAt = gtime.Now()
defaultTypeMapping = map[string]TypeMapping{
"decimal": {
Type: "float64",
},
"money": {
Type: "float64",
},
"numeric": {
Type: "float64",
},
"smallmoney": {
Type: "float64",
},
}
)
func init() {
@@ -129,6 +158,7 @@ func init() {
`CGenDaoBriefNoJsonTag`: CGenDaoBriefNoJsonTag,
`CGenDaoBriefNoModelComment`: CGenDaoBriefNoModelComment,
`CGenDaoBriefClear`: CGenDaoBriefClear,
`CGenDaoBriefTypeMapping`: CGenDaoBriefTypeMapping,
`CGenDaoBriefGroup`: CGenDaoBriefGroup,
`CGenDaoBriefJsonCase`: CGenDaoBriefJsonCase,
`CGenDaoBriefTplDaoIndexPath`: CGenDaoBriefTplDaoIndexPath,
@@ -142,30 +172,31 @@ type (
CGenDao struct{}
CGenDaoInput struct {
g.Meta `name:"dao" config:"{CGenDaoConfig}" usage:"{CGenDaoUsage}" brief:"{CGenDaoBrief}" eg:"{CGenDaoEg}" ad:"{CGenDaoAd}"`
Path string `name:"path" short:"p" brief:"{CGenDaoBriefPath}" d:"internal"`
Link string `name:"link" short:"l" brief:"{CGenDaoBriefLink}"`
Tables string `name:"tables" short:"t" brief:"{CGenDaoBriefTables}"`
TablesEx string `name:"tablesEx" short:"x" brief:"{CGenDaoBriefTablesEx}"`
Group string `name:"group" short:"g" brief:"{CGenDaoBriefGroup}" d:"default"`
Prefix string `name:"prefix" short:"f" brief:"{CGenDaoBriefPrefix}"`
RemovePrefix string `name:"removePrefix" short:"r" brief:"{CGenDaoBriefRemovePrefix}"`
JsonCase string `name:"jsonCase" short:"j" brief:"{CGenDaoBriefJsonCase}" d:"CamelLower"`
ImportPrefix string `name:"importPrefix" short:"i" brief:"{CGenDaoBriefImportPrefix}"`
DaoPath string `name:"daoPath" short:"d" brief:"{CGenDaoBriefDaoPath}" d:"dao"`
DoPath string `name:"doPath" short:"o" brief:"{CGenDaoBriefDoPath}" d:"model/do"`
EntityPath string `name:"entityPath" short:"e" brief:"{CGenDaoBriefEntityPath}" d:"model/entity"`
TplDaoIndexPath string `name:"tplDaoIndexPath" short:"t1" brief:"{CGenDaoBriefTplDaoIndexPath}"`
TplDaoInternalPath string `name:"tplDaoInternalPath" short:"t2" brief:"{CGenDaoBriefTplDaoInternalPath}"`
TplDaoDoPath string `name:"tplDaoDoPath" short:"t3" brief:"{CGenDaoBriefTplDaoDoPathPath}"`
TplDaoEntityPath string `name:"tplDaoEntityPath" short:"t4" brief:"{CGenDaoBriefTplDaoEntityPath}"`
StdTime bool `name:"stdTime" short:"s" brief:"{CGenDaoBriefStdTime}" orphan:"true"`
WithTime bool `name:"withTime" short:"w" brief:"{CGenDaoBriefWithTime}" orphan:"true"`
GJsonSupport bool `name:"gJsonSupport" short:"n" brief:"{CGenDaoBriefGJsonSupport}" orphan:"true"`
OverwriteDao bool `name:"overwriteDao" short:"v" brief:"{CGenDaoBriefOverwriteDao}" orphan:"true"`
DescriptionTag bool `name:"descriptionTag" short:"c" brief:"{CGenDaoBriefDescriptionTag}" orphan:"true"`
NoJsonTag bool `name:"noJsonTag" short:"k" brief:"{CGenDaoBriefNoJsonTag}" orphan:"true"`
NoModelComment bool `name:"noModelComment" short:"m" brief:"{CGenDaoBriefNoModelComment}" orphan:"true"`
Clear bool `name:"clear" short:"a" brief:"{CGenDaoBriefClear}" orphan:"true"`
Path string `name:"path" short:"p" brief:"{CGenDaoBriefPath}" d:"internal"`
Link string `name:"link" short:"l" brief:"{CGenDaoBriefLink}"`
Tables string `name:"tables" short:"t" brief:"{CGenDaoBriefTables}"`
TablesEx string `name:"tablesEx" short:"x" brief:"{CGenDaoBriefTablesEx}"`
Group string `name:"group" short:"g" brief:"{CGenDaoBriefGroup}" d:"default"`
Prefix string `name:"prefix" short:"f" brief:"{CGenDaoBriefPrefix}"`
RemovePrefix string `name:"removePrefix" short:"r" brief:"{CGenDaoBriefRemovePrefix}"`
JsonCase string `name:"jsonCase" short:"j" brief:"{CGenDaoBriefJsonCase}" d:"CamelLower"`
ImportPrefix string `name:"importPrefix" short:"i" brief:"{CGenDaoBriefImportPrefix}"`
DaoPath string `name:"daoPath" short:"d" brief:"{CGenDaoBriefDaoPath}" d:"dao"`
DoPath string `name:"doPath" short:"o" brief:"{CGenDaoBriefDoPath}" d:"model/do"`
EntityPath string `name:"entityPath" short:"e" brief:"{CGenDaoBriefEntityPath}" d:"model/entity"`
TplDaoIndexPath string `name:"tplDaoIndexPath" short:"t1" brief:"{CGenDaoBriefTplDaoIndexPath}"`
TplDaoInternalPath string `name:"tplDaoInternalPath" short:"t2" brief:"{CGenDaoBriefTplDaoInternalPath}"`
TplDaoDoPath string `name:"tplDaoDoPath" short:"t3" brief:"{CGenDaoBriefTplDaoDoPathPath}"`
TplDaoEntityPath string `name:"tplDaoEntityPath" short:"t4" brief:"{CGenDaoBriefTplDaoEntityPath}"`
StdTime bool `name:"stdTime" short:"s" brief:"{CGenDaoBriefStdTime}" orphan:"true"`
WithTime bool `name:"withTime" short:"w" brief:"{CGenDaoBriefWithTime}" orphan:"true"`
GJsonSupport bool `name:"gJsonSupport" short:"n" brief:"{CGenDaoBriefGJsonSupport}" orphan:"true"`
OverwriteDao bool `name:"overwriteDao" short:"v" brief:"{CGenDaoBriefOverwriteDao}" orphan:"true"`
DescriptionTag bool `name:"descriptionTag" short:"c" brief:"{CGenDaoBriefDescriptionTag}" orphan:"true"`
NoJsonTag bool `name:"noJsonTag" short:"k" brief:"{CGenDaoBriefNoJsonTag}" orphan:"true"`
NoModelComment bool `name:"noModelComment" short:"m" brief:"{CGenDaoBriefNoModelComment}" orphan:"true"`
Clear bool `name:"clear" short:"a" brief:"{CGenDaoBriefClear}" orphan:"true"`
TypeMapping map[string]TypeMapping `name:"typeMapping" short:"y" brief:"{CGenDaoBriefTypeMapping}" orphan:"true"`
}
CGenDaoOutput struct{}
@@ -174,7 +205,11 @@ type (
DB gdb.DB
TableNames []string
NewTableNames []string
ModName string // Module name of current golang project, which is used for import purpose.
}
TypeMapping struct {
Type string `brief:"custom attribute type name"`
Import string `brief:"custom import for this type"`
}
)
@@ -195,16 +230,11 @@ func (c CGenDao) Dao(ctx context.Context, in CGenDaoInput) (out *CGenDaoOutput,
return
}
func DoGenDaoForArray(ctx context.Context, in CGenDaoInput) {
doGenDaoForArray(ctx, -1, in)
}
// doGenDaoForArray implements the "gen dao" command for configuration array.
func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) {
var (
err error
db gdb.DB
modName string // Go module name, eg: github.com/gogf/gf.
err error
db gdb.DB
)
if index >= 0 {
err = g.Cfg().MustGet(
@@ -219,20 +249,6 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) {
mlog.Fatalf(`path "%s" does not exist`, in.Path)
}
removePrefixArray := gstr.SplitAndTrim(in.RemovePrefix, ",")
if in.ImportPrefix == "" {
if !gfile.Exists("go.mod") {
mlog.Fatal("go.mod does not exist in current working directory")
}
var (
goModContent = gfile.GetContents("go.mod")
match, _ = gregex.MatchString(`^module\s+(.+)\s*`, goModContent)
)
if len(match) > 1 {
modName = gstr.Trim(match[1])
} else {
mlog.Fatal("module name does not found in go.mod")
}
}
// It uses user passed database configuration.
if in.Link != "" {
@@ -268,6 +284,17 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) {
tableNames = array.Slice()
}
// merge default typeMapping to input typeMapping.
if in.TypeMapping == nil {
in.TypeMapping = defaultTypeMapping
} else {
for key, typeMapping := range defaultTypeMapping {
if _, ok := in.TypeMapping[key]; !ok {
in.TypeMapping[key] = typeMapping
}
}
}
// Generating dao & model go files one by one according to given table name.
newTableNames := make([]string, len(tableNames))
for i, tableName := range tableNames {
@@ -278,13 +305,13 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) {
newTableName = in.Prefix + newTableName
newTableNames[i] = newTableName
}
// Dao: index and internal.
generateDao(ctx, CGenDaoInternalInput{
CGenDaoInput: in,
DB: db,
TableNames: tableNames,
NewTableNames: newTableNames,
ModName: modName,
})
// Do.
generateDo(ctx, CGenDaoInternalInput{
@@ -292,7 +319,6 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) {
DB: db,
TableNames: tableNames,
NewTableNames: newTableNames,
ModName: modName,
})
// Entity.
generateEntity(ctx, CGenDaoInternalInput{
@@ -300,15 +326,11 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) {
DB: db,
TableNames: tableNames,
NewTableNames: newTableNames,
ModName: modName,
})
}
func getImportPartContent(source string, isDo bool) string {
var (
packageImportsArray = garray.NewStrArray()
)
func getImportPartContent(ctx context.Context, source string, isDo bool, appendImports []string) string {
var packageImportsArray = garray.NewStrArray()
if isDo {
packageImportsArray.Append(`"github.com/gogf/gf/v2/frame/g"`)
}
@@ -325,6 +347,32 @@ func getImportPartContent(source string, isDo bool) string {
packageImportsArray.Append(`"github.com/gogf/gf/v2/encoding/gjson"`)
}
// Check and update imports in go.mod
if appendImports != nil && len(appendImports) > 0 {
goModPath := utils.GetModPath()
if goModPath == "" {
mlog.Fatal("go.mod not found in current project")
}
mod, err := modfile.Parse(goModPath, gfile.GetBytes(goModPath), nil)
if err != nil {
mlog.Fatalf("parse go.mod failed: %+v", err)
}
for _, appendImport := range appendImports {
found := false
for _, require := range mod.Require {
if gstr.Contains(appendImport, require.Mod.Path) {
found = true
break
}
}
if !found {
err = gproc.ShellRun(ctx, `go get `+appendImport)
mlog.Fatalf(`%+v`, err)
}
packageImportsArray.Append(fmt.Sprintf(`"%s"`, appendImport))
}
}
// Generate and write content to golang file.
packageImportsStr := ""
if packageImportsArray.Len() > 0 {

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gendao
import (

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gendao
import (
@@ -6,12 +12,12 @@ import (
"fmt"
"strings"
"github.com/olekukonko/tablewriter"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/olekukonko/tablewriter"
"hotgo/internal/library/hggen/internal/consts"
"hotgo/internal/library/hggen/internal/utility/mlog"
@@ -53,23 +59,13 @@ func generateDaoSingle(ctx context.Context, in generateDaoSingleInput) {
mlog.Fatalf(`fetching tables fields failed for table "%s": %+v`, in.TableName, err)
}
var (
dirRealPath = gfile.RealPath(in.Path)
tableNameCamelCase = gstr.CaseCamel(in.NewTableName)
tableNameCamelLowerCase = gstr.CaseCamelLower(in.NewTableName)
tableNameSnakeCase = gstr.CaseSnake(in.NewTableName)
importPrefix = in.ImportPrefix
)
if importPrefix == "" {
if dirRealPath == "" {
dirRealPath = in.Path
importPrefix = dirRealPath
importPrefix = gstr.Trim(dirRealPath, "./")
} else {
importPrefix = gstr.Replace(dirRealPath, gfile.Pwd(), "")
}
importPrefix = gstr.Replace(importPrefix, gfile.Separator, "/")
importPrefix = gstr.Join(g.SliceStr{in.ModName, importPrefix, in.DaoPath}, "/")
importPrefix, _ = gregex.ReplaceString(`\/{2,}`, `/`, gstr.Trim(importPrefix, "/"))
importPrefix = utils.GetImportPath(gfile.Join(in.Path, in.DaoPath))
} else {
importPrefix = gstr.Join(g.SliceStr{importPrefix, in.DaoPath}, "/")
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gendao
import (
@@ -30,9 +36,9 @@ func generateDo(ctx context.Context, in CGenDaoInternalInput) {
mlog.Fatalf("fetching tables fields failed for table '%s':\n%v", tableName, err)
}
var (
newTableName = in.NewTableNames[i]
doFilePath = gfile.Join(dirPathDo, gstr.CaseSnake(newTableName)+".go")
structDefinition = generateStructDefinition(ctx, generateStructDefinitionInput{
newTableName = in.NewTableNames[i]
doFilePath = gfile.Join(dirPathDo, gstr.CaseSnake(newTableName)+".go")
structDefinition, _ = generateStructDefinition(ctx, generateStructDefinitionInput{
CGenDaoInternalInput: in,
TableName: tableName,
StructName: gstr.CaseCamel(newTableName),
@@ -53,6 +59,7 @@ func generateDo(ctx context.Context, in CGenDaoInternalInput) {
},
)
modelContent := generateDoContent(
ctx,
in,
tableName,
gstr.CaseCamel(newTableName),
@@ -68,15 +75,18 @@ func generateDo(ctx context.Context, in CGenDaoInternalInput) {
}
}
func generateDoContent(in CGenDaoInternalInput, tableName, tableNameCamelCase, structDefine string) string {
func generateDoContent(
ctx context.Context, in CGenDaoInternalInput, tableName, tableNameCamelCase, structDefine string,
) string {
doContent := gstr.ReplaceByMap(
getTemplateFromPathOrDefault(in.TplDaoDoPath, consts.TemplateGenDaoDoContent),
g.MapStrStr{
tplVarTableName: tableName,
tplVarPackageImports: getImportPartContent(structDefine, true),
tplVarPackageImports: getImportPartContent(ctx, structDefine, true, nil),
tplVarTableNameCamelCase: tableNameCamelCase,
tplVarStructDefine: structDefine,
})
},
)
doContent = replaceDefaultVar(in, doContent)
return doContent
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gendao
import (
@@ -24,22 +30,27 @@ func generateEntity(ctx context.Context, in CGenDaoInternalInput) {
if err != nil {
mlog.Fatalf("fetching tables fields failed for table '%s':\n%v", tableName, err)
}
var (
newTableName = in.NewTableNames[i]
entityFilePath = gfile.Join(dirPathEntity, gstr.CaseSnake(newTableName)+".go")
entityContent = generateEntityContent(
newTableName = in.NewTableNames[i]
entityFilePath = gfile.Join(dirPathEntity, gstr.CaseSnake(newTableName)+".go")
structDefinition, appendImports = generateStructDefinition(ctx, generateStructDefinitionInput{
CGenDaoInternalInput: in,
TableName: tableName,
StructName: gstr.CaseCamel(newTableName),
FieldMap: fieldMap,
IsDo: false,
})
entityContent = generateEntityContent(
ctx,
in,
newTableName,
gstr.CaseCamel(newTableName),
generateStructDefinition(ctx, generateStructDefinitionInput{
CGenDaoInternalInput: in,
TableName: tableName,
StructName: gstr.CaseCamel(newTableName),
FieldMap: fieldMap,
IsDo: false,
}),
structDefinition,
appendImports,
)
)
err = gfile.PutContents(entityFilePath, strings.TrimSpace(entityContent))
if err != nil {
mlog.Fatalf("writing content to '%s' failed: %v", entityFilePath, err)
@@ -50,15 +61,18 @@ func generateEntity(ctx context.Context, in CGenDaoInternalInput) {
}
}
func generateEntityContent(in CGenDaoInternalInput, tableName, tableNameCamelCase, structDefine string) string {
func generateEntityContent(
ctx context.Context, in CGenDaoInternalInput, tableName, tableNameCamelCase, structDefine string, appendImports []string,
) string {
entityContent := gstr.ReplaceByMap(
getTemplateFromPathOrDefault(in.TplDaoEntityPath, consts.TemplateGenDaoEntityContent),
g.MapStrStr{
tplVarTableName: tableName,
tplVarPackageImports: getImportPartContent(structDefine, false),
tplVarPackageImports: getImportPartContent(ctx, structDefine, false, appendImports),
tplVarTableNameCamelCase: tableNameCamelCase,
tplVarStructDefine: structDefine,
})
},
)
entityContent = replaceDefaultVar(in, entityContent)
return entityContent
}

View File

@@ -1,15 +1,22 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gendao
import (
"bytes"
"context"
"fmt"
"github.com/olekukonko/tablewriter"
"strings"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/olekukonko/tablewriter"
)
type generateStructDefinitionInput struct {
@@ -20,13 +27,18 @@ type generateStructDefinitionInput struct {
IsDo bool // Is generating DTO struct.
}
func generateStructDefinition(ctx context.Context, in generateStructDefinitionInput) string {
func generateStructDefinition(ctx context.Context, in generateStructDefinitionInput) (string, []string) {
var appendImports []string
buffer := bytes.NewBuffer(nil)
array := make([][]string, len(in.FieldMap))
names := sortFieldKeyForDao(in.FieldMap)
for index, name := range names {
var imports string
field := in.FieldMap[name]
array[index] = generateStructFieldDefinition(ctx, field, in)
array[index], imports = generateStructFieldDefinition(ctx, field, in)
if imports != "" {
appendImports = append(appendImports, imports)
}
}
tw := tablewriter.NewWriter(buffer)
tw.SetBorder(false)
@@ -47,21 +59,42 @@ func generateStructDefinition(ctx context.Context, in generateStructDefinitionIn
}
buffer.WriteString(stContent)
buffer.WriteString("}")
return buffer.String()
return buffer.String(), appendImports
}
// generateStructFieldDefinition generates and returns the attribute definition for specified field.
func generateStructFieldDefinition(
ctx context.Context, field *gdb.TableField, in generateStructDefinitionInput,
) []string {
) (attrLines []string, appendImport string) {
var (
err error
typeName string
jsonTag = getJsonTagFromCase(field.Name, in.JsonCase)
)
typeName, err = in.DB.CheckLocalTypeForField(ctx, field.Type, nil)
if err != nil {
panic(err)
if in.TypeMapping != nil && len(in.TypeMapping) > 0 {
var (
tryTypeName string
)
tryTypeMatch, _ := gregex.MatchString(`(.+?)\((.+)\)`, field.Type)
if len(tryTypeMatch) == 3 {
tryTypeName = gstr.Trim(tryTypeMatch[1])
} else {
tryTypeName = gstr.Split(field.Type, " ")[0]
}
if tryTypeName != "" {
if typeMapping, ok := in.TypeMapping[strings.ToLower(tryTypeName)]; ok {
typeName = typeMapping.Type
appendImport = typeMapping.Import
}
}
}
if typeName == "" {
typeName, err = in.DB.CheckLocalTypeForField(ctx, field.Type, nil)
if err != nil {
panic(err)
}
}
switch typeName {
case gdb.LocalTypeDate, gdb.LocalTypeDatetime:
@@ -87,19 +120,18 @@ func generateStructFieldDefinition(
}
var (
tagKey = "`"
result = []string{
" #" + gstr.CaseCamel(field.Name),
" #" + typeName,
}
tagKey = "`"
descriptionTag = gstr.Replace(formatComment(field.Comment), `"`, `\"`)
)
attrLines = []string{
" #" + gstr.CaseCamel(field.Name),
" #" + typeName,
}
attrLines = append(attrLines, " #"+fmt.Sprintf(tagKey+`json:"%s"`, jsonTag))
attrLines = append(attrLines, " #"+fmt.Sprintf(`description:"%s"`+tagKey, descriptionTag))
attrLines = append(attrLines, " #"+fmt.Sprintf(`// %s`, formatComment(field.Comment)))
result = append(result, " #"+fmt.Sprintf(tagKey+`json:"%s"`, jsonTag))
result = append(result, " #"+fmt.Sprintf(`description:"%s"`+tagKey, descriptionTag))
result = append(result, " #"+fmt.Sprintf(`// %s`, formatComment(field.Comment)))
for k, v := range result {
for k, v := range attrLines {
if in.NoJsonTag {
v, _ = gregex.ReplaceString(`json:".+"`, ``, v)
}
@@ -109,9 +141,9 @@ func generateStructFieldDefinition(
if in.NoModelComment {
v, _ = gregex.ReplaceString(`//.+`, ``, v)
}
result[k] = v
attrLines[k] = v
}
return result
return attrLines, appendImport
}
// formatComment formats the comment string to fit the golang code without any lines.

View File

@@ -21,10 +21,10 @@ import (
type (
CGenEnums struct{}
CGenEnumsInput struct {
g.Meta `name:"enums" config:"{CGenEnumsConfig}" brief:"{CGenEnumsBrief}" eg:"{CGenEnumsEg}"`
Src string `name:"src" short:"s" dc:"source folder path to be parsed" d:"."`
Path string `name:"path" short:"p" dc:"output go file path storing enums content" d:"internal/boot/boot_enums.go"`
Prefix string `name:"prefix" short:"x" dc:"only exports packages that starts with specified prefix"`
g.Meta `name:"enums" config:"{CGenEnumsConfig}" brief:"{CGenEnumsBrief}" eg:"{CGenEnumsEg}"`
Src string `name:"src" short:"s" dc:"source folder path to be parsed" d:"."`
Path string `name:"path" short:"p" dc:"output go file path storing enums content" d:"internal/boot/boot_enums.go"`
Prefixes []string `name:"prefixes" short:"x" dc:"only exports packages that starts with specified prefixes"`
}
CGenEnumsOutput struct{}
)
@@ -57,7 +57,7 @@ func (c CGenEnums) Enums(ctx context.Context, in CGenEnumsInput) (out *CGenEnums
if err != nil {
mlog.Fatal(err)
}
mlog.Printf(`scanning: %s`, realPath)
mlog.Printf(`scanning for enums: %s`, realPath)
cfg := &packages.Config{
Dir: realPath,
Mode: pkgLoadMode,
@@ -67,7 +67,7 @@ func (c CGenEnums) Enums(ctx context.Context, in CGenEnumsInput) (out *CGenEnums
if err != nil {
mlog.Fatal(err)
}
p := NewEnumsParser(in.Prefix)
p := NewEnumsParser(in.Prefixes)
p.ParsePackages(pkgs)
var enumsContent = gstr.ReplaceByMap(consts.TemplateGenEnums, g.MapStrStr{
"{PackageName}": gfile.Basename(gfile.Dir(in.Path)),
@@ -77,7 +77,7 @@ func (c CGenEnums) Enums(ctx context.Context, in CGenEnumsInput) (out *CGenEnums
if err = gfile.PutContents(in.Path, enumsContent); err != nil {
return
}
mlog.Printf(`generated: %s`, in.Path)
mlog.Printf(`generated enums go file: %s`, in.Path)
mlog.Print("done!")
return
}

View File

@@ -21,7 +21,7 @@ const pkgLoadMode = 0xffffff
type EnumsParser struct {
enums []EnumItem
parsedPkg map[string]struct{}
prefix string
prefixes []string
}
type EnumItem struct {
@@ -33,21 +33,21 @@ type EnumItem struct {
var standardPackages = make(map[string]struct{})
//func init() {
// stdPackages, err := packages.Load(nil, "std")
// if err != nil {
// panic(err)
// }
// for _, p := range stdPackages {
// standardPackages[p.ID] = struct{}{}
// }
//}
func init() {
stdPackages, err := packages.Load(nil, "std")
if err != nil {
panic(err)
}
for _, p := range stdPackages {
standardPackages[p.ID] = struct{}{}
}
}
func NewEnumsParser(prefix string) *EnumsParser {
func NewEnumsParser(prefixes []string) *EnumsParser {
return &EnumsParser{
enums: make([]EnumItem, 0),
parsedPkg: make(map[string]struct{}),
prefix: prefix,
prefixes: prefixes,
}
}
@@ -67,9 +67,16 @@ func (p *EnumsParser) ParsePackage(pkg *packages.Package) {
return
}
p.parsedPkg[pkg.ID] = struct{}{}
// Only parse specified prefix.
if p.prefix != "" {
if !gstr.HasPrefix(pkg.ID, p.prefix) {
// Only parse specified prefixes.
if len(p.prefixes) > 0 {
var hasPrefix bool
for _, prefix := range p.prefixes {
if hasPrefix = gstr.HasPrefix(pkg.ID, prefix); hasPrefix {
break
}
}
if !hasPrefix {
return
}
}
@@ -108,7 +115,6 @@ func (p *EnumsParser) ParsePackage(pkg *packages.Package) {
Type: enumType,
Kind: enumKind,
})
}
for _, im := range pkg.Imports {
p.ParsePackage(im)

View File

@@ -0,0 +1,128 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genpb
import (
"context"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/util/gtag"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
type (
CGenPb struct{}
CGenPbInput struct {
g.Meta `name:"pb" config:"{CGenPbConfig}" brief:"{CGenPbBrief}" eg:"{CGenPbEg}"`
Path string `name:"path" short:"p" dc:"protobuf file folder path" d:"manifest/protobuf"`
OutputApi string `name:"api" short:"a" dc:"output folder path storing generated go files of api" d:"api"`
OutputCtrl string `name:"ctrl" short:"c" dc:"output folder path storing generated go files of controller" d:"internal/controller"`
}
CGenPbOutput struct{}
)
const (
CGenPbConfig = `gfcli.gen.pb`
CGenPbBrief = `parse proto files and generate protobuf go files`
CGenPbEg = `
gf gen pb
gf gen pb -p . -a . -p .
`
)
func init() {
gtag.Sets(g.MapStrStr{
`CGenPbEg`: CGenPbEg,
`CGenPbBrief`: CGenPbBrief,
`CGenPbConfig`: CGenPbConfig,
})
}
func (c CGenPb) Pb(ctx context.Context, in CGenPbInput) (out *CGenPbOutput, err error) {
// Necessary check.
protoc := gproc.SearchBinary("protoc")
if protoc == "" {
mlog.Fatalf(`command "protoc" not found in your environment, please install protoc first: https://grpc.io/docs/languages/go/quickstart/`)
}
// protocol fold checks.
var (
protoPath = gfile.RealPath(in.Path)
isParsingPWD bool
)
if protoPath == "" {
// Use current working directory as protoPath if there are proto files under.
currentPath := gfile.Pwd()
currentFiles, _ := gfile.ScanDirFile(currentPath, "*.proto")
if len(currentFiles) > 0 {
protoPath = currentPath
isParsingPWD = true
} else {
mlog.Fatalf(`proto files folder "%s" does not exist`, in.Path)
}
}
// output path checks.
outputApiPath := gfile.RealPath(in.OutputApi)
if outputApiPath == "" {
if isParsingPWD {
outputApiPath = protoPath
} else {
mlog.Fatalf(`output api folder "%s" does not exist`, in.OutputApi)
}
}
outputCtrlPath := gfile.RealPath(in.OutputCtrl)
if outputCtrlPath == "" {
if isParsingPWD {
outputCtrlPath = ""
} else {
mlog.Fatalf(`output controller folder "%s" does not exist`, in.OutputCtrl)
}
}
// folder scanning.
files, err := gfile.ScanDirFile(protoPath, "*.proto", true)
if err != nil {
mlog.Fatal(err)
}
if len(files) == 0 {
mlog.Fatalf(`no proto files found in folder "%s"`, in.Path)
}
if err = gfile.Chdir(protoPath); err != nil {
mlog.Fatal(err)
}
for _, file := range files {
var command = gproc.NewProcess(protoc, nil)
command.Args = append(command.Args, "--proto_path="+gfile.Pwd())
command.Args = append(command.Args, "--go_out=paths=source_relative:"+outputApiPath)
command.Args = append(command.Args, "--go-grpc_out=paths=source_relative:"+outputApiPath)
command.Args = append(command.Args, file)
mlog.Print(command.String())
if err = command.Run(ctx); err != nil {
mlog.Fatal(err)
}
}
// Generate struct tag according comment rules.
err = c.generateStructTag(ctx, generateStructTagInput{OutputApiPath: outputApiPath})
if err != nil {
return
}
// Generate controllers according comment rules.
if outputCtrlPath != "" {
err = c.generateController(ctx, generateControllerInput{
OutputApiPath: outputApiPath,
OutputCtrlPath: outputCtrlPath,
})
if err != nil {
return
}
}
mlog.Print("done!")
return
}

View File

@@ -0,0 +1,196 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genpb
import (
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"hotgo/internal/library/hggen/internal/utility/utils"
)
type generateControllerInput struct {
OutputApiPath string
OutputCtrlPath string
}
type generateCtrl struct {
Name string
Package string
Version string
Methods []generateCtrlMethod
}
type generateCtrlMethod struct {
Name string
Definition string
}
const (
controllerTemplate = `
package {Package}
type Controller struct {
{Version}.Unimplemented{Name}Server
}
func Register(s *grpcx.GrpcServer) {
{Version}.Register{Name}Server(s.Server, &Controller{})
}
`
controllerMethodTemplate = `
func (*Controller) {Definition} {
return nil, gerror.NewCode(gcode.CodeNotImplemented)
}
`
)
func (c CGenPb) generateController(ctx context.Context, in generateControllerInput) (err error) {
files, err := gfile.ScanDirFile(in.OutputApiPath, "*_grpc.pb.go", true)
if err != nil {
return err
}
var controllers []generateCtrl
for _, file := range files {
fileControllers, err := c.parseControllers(file)
if err != nil {
return err
}
controllers = append(controllers, fileControllers...)
}
if len(controllers) == 0 {
return nil
}
// Generate controller files.
err = c.doGenerateControllers(in, controllers)
return
}
func (c CGenPb) parseControllers(filePath string) ([]generateCtrl, error) {
var (
controllers []generateCtrl
content = gfile.GetContents(filePath)
)
_, err := gregex.ReplaceStringFuncMatch(
`type (\w+)Server interface {([\s\S]+?)}`,
content,
func(match []string) string {
ctrl := generateCtrl{
Name: match[1],
Package: strings.ReplaceAll(gfile.Basename(gfile.Dir(gfile.Dir(filePath))), "-", "_"),
Version: gfile.Basename(gfile.Dir(filePath)),
Methods: make([]generateCtrlMethod, 0),
}
lines := gstr.Split(match[2], "\n")
for _, line := range lines {
line = gstr.Trim(line)
if line == "" || !gstr.IsLetterUpper(line[0]) {
continue
}
// Comment.
if gregex.IsMatchString(`^//.+`, line) {
continue
}
line, _ = gregex.ReplaceStringFuncMatch(
`^(\w+)\(context\.Context, \*(\w+)\) \(\*(\w+), error\)$`,
line,
func(match []string) string {
return fmt.Sprintf(
`%s(ctx context.Context, req *%s.%s) (res *%s.%s, err error)`,
match[1], ctrl.Version, match[2], ctrl.Version, match[3],
)
},
)
ctrl.Methods = append(ctrl.Methods, generateCtrlMethod{
Name: gstr.Split(line, "(")[0],
Definition: line,
})
}
if len(ctrl.Methods) > 0 {
controllers = append(controllers, ctrl)
}
return match[0]
},
)
return controllers, err
}
func (c CGenPb) doGenerateControllers(in generateControllerInput, controllers []generateCtrl) (err error) {
for _, controller := range controllers {
err = c.doGenerateController(in, controller)
if err != nil {
return err
}
}
err = utils.ReplaceGeneratedContentGFV2(in.OutputCtrlPath)
return nil
}
func (c CGenPb) doGenerateController(in generateControllerInput, controller generateCtrl) (err error) {
var (
folderPath = gfile.Join(in.OutputCtrlPath, controller.Package)
filePath = gfile.Join(folderPath, controller.Package+".go")
isDirty bool
)
if !gfile.Exists(folderPath) {
if err = gfile.Mkdir(folderPath); err != nil {
return err
}
}
if !gfile.Exists(filePath) {
templateContent := gstr.ReplaceByMap(controllerTemplate, g.MapStrStr{
"{Name}": controller.Name,
"{Version}": controller.Version,
"{Package}": controller.Package,
})
if err = gfile.PutContents(filePath, templateContent); err != nil {
return err
}
isDirty = true
}
// Exist controller content.
var ctrlContent string
files, err := gfile.ScanDirFile(folderPath, "*.go", false)
if err != nil {
return err
}
for _, file := range files {
if ctrlContent != "" {
ctrlContent += "\n"
}
ctrlContent += gfile.GetContents(file)
}
// Generate method content.
var generatedContent string
for _, method := range controller.Methods {
if gstr.Contains(ctrlContent, fmt.Sprintf(`%s(`, method.Name)) {
continue
}
if generatedContent != "" {
generatedContent += "\n"
}
generatedContent += gstr.ReplaceByMap(controllerMethodTemplate, g.MapStrStr{
"{Definition}": method.Definition,
})
}
if generatedContent != "" {
err = gfile.PutContentsAppend(filePath, generatedContent)
if err != nil {
return err
}
isDirty = true
}
if isDirty {
utils.GoFmt(filePath)
}
return nil
}

View File

@@ -0,0 +1,113 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genpb
import (
"context"
"fmt"
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/library/hggen/internal/utility/utils"
)
type generateStructTagInput struct {
OutputApiPath string
}
func (c CGenPb) generateStructTag(ctx context.Context, in generateStructTagInput) (err error) {
files, err := gfile.ScanDirFile(in.OutputApiPath, "*.pb.go", true)
if err != nil {
return err
}
var content string
for _, file := range files {
content = gfile.GetContents(file)
content, err = c.doTagReplacement(ctx, content)
if err != nil {
return err
}
if err = gfile.PutContents(file, content); err != nil {
return err
}
utils.GoFmt(file)
}
return
}
func (c CGenPb) doTagReplacement(ctx context.Context, content string) (string, error) {
content, err := gregex.ReplaceStringFuncMatch(`type (\w+) struct {([\s\S]+?)}`, content, func(match []string) string {
var (
topCommentMatch []string
tailCommentMatch []string
lines = gstr.Split(match[2], "\n")
lineTagMap = gmap.NewListMap()
)
for index, line := range lines {
line = gstr.Trim(line)
if line == "" {
continue
}
// Top comment.
topCommentMatch, _ = gregex.MatchString(`^/[/|\*](.+)`, line)
if len(topCommentMatch) > 1 {
c.tagCommentIntoListMap(gstr.Trim(topCommentMatch[1]), lineTagMap)
continue
}
// Tail comment.
tailCommentMatch, _ = gregex.MatchString(".+?`.+?`.+?//(.+)", line)
if len(tailCommentMatch) > 1 {
c.tagCommentIntoListMap(gstr.Trim(tailCommentMatch[1]), lineTagMap)
}
// Tag injection.
if !lineTagMap.IsEmpty() {
tagContent := c.listMapToStructTag(lineTagMap)
lineTagMap.Clear()
line, _ = gregex.ReplaceString("`(.+)`", fmt.Sprintf("`$1 %s`", tagContent), line)
}
lines[index] = line
}
match[2] = gstr.Join(lines, "\n")
return fmt.Sprintf("type %s struct {%s}", match[1], match[2])
})
return content, err
}
func (c CGenPb) tagCommentIntoListMap(comment string, lineTagMap *gmap.ListMap) {
tagCommentMatch, _ := gregex.MatchString(`^(\w+):(.+)`, comment)
if len(tagCommentMatch) > 1 {
var (
tagName = gstr.Trim(tagCommentMatch[1])
tagContent = gstr.Trim(tagCommentMatch[2])
)
lineTagMap.Set(tagName, lineTagMap.GetVar(tagName).String()+tagContent)
} else {
var (
tagName = "dc"
tagContent = comment
)
lineTagMap.Set(tagName, lineTagMap.GetVar(tagName).String()+tagContent)
}
}
func (c CGenPb) listMapToStructTag(lineTagMap *gmap.ListMap) string {
var tag string
lineTagMap.Iterator(func(key, value interface{}) bool {
if tag != "" {
tag += " "
}
tag += fmt.Sprintf(
`%s:"%s"`,
key, gstr.Replace(gconv.String(value), `"`, `\"`),
)
return true
})
return tag
}

View File

@@ -0,0 +1,419 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genpbentity
import (
"bytes"
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gtag"
"github.com/olekukonko/tablewriter"
"hotgo/internal/library/hggen/internal/consts"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
type (
CGenPbEntity struct{}
CGenPbEntityInput struct {
g.Meta `name:"pbentity" config:"{CGenPbEntityConfig}" brief:"{CGenPbEntityBrief}" eg:"{CGenPbEntityEg}" ad:"{CGenPbEntityAd}"`
Path string `name:"path" short:"p" brief:"{CGenPbEntityBriefPath}" d:"manifest/protobuf/pbentity"`
Package string `name:"package" short:"k" brief:"{CGenPbEntityBriefPackage}"`
Link string `name:"link" short:"l" brief:"{CGenPbEntityBriefLink}"`
Tables string `name:"tables" short:"t" brief:"{CGenPbEntityBriefTables}"`
Prefix string `name:"prefix" short:"f" brief:"{CGenPbEntityBriefPrefix}"`
RemovePrefix string `name:"removePrefix" short:"r" brief:"{CGenPbEntityBriefRemovePrefix}"`
NameCase string `name:"nameCase" short:"n" brief:"{CGenPbEntityBriefNameCase}" d:"Camel"`
JsonCase string `name:"jsonCase" short:"j" brief:"{CGenPbEntityBriefJsonCase}" d:"CamelLower"`
Option string `name:"option" short:"o" brief:"{CGenPbEntityBriefOption}"`
}
CGenPbEntityOutput struct{}
CGenPbEntityInternalInput struct {
CGenPbEntityInput
DB gdb.DB
TableName string // TableName specifies the table name of the table.
NewTableName string // NewTableName specifies the prefix-stripped name of the table.
}
)
const (
defaultPackageSuffix = `api/pbentity`
CGenPbEntityConfig = `gfcli.gen.pbentity`
CGenPbEntityBrief = `generate entity message files in protobuf3 format`
CGenPbEntityEg = `
gf gen pbentity
gf gen pbentity -l "mysql:root:12345678@tcp(127.0.0.1:3306)/test"
gf gen pbentity -p ./protocol/demos/entity -t user,user_detail,user_login
gf gen pbentity -r user_ -k github.com/gogf/gf/example/protobuf
gf gen pbentity -r user_
`
CGenPbEntityAd = `
CONFIGURATION SUPPORT
Options are also supported by configuration file.
It's suggested using configuration file instead of command line arguments making producing.
The configuration node name is "gf.gen.pbentity", which also supports multiple databases, for example(config.yaml):
gfcli:
gen:
- pbentity:
link: "mysql:root:12345678@tcp(127.0.0.1:3306)/test"
path: "protocol/demos/entity"
tables: "order,products"
package: "demos"
- pbentity:
link: "mysql:root:12345678@tcp(127.0.0.1:3306)/primary"
path: "protocol/demos/entity"
prefix: "primary_"
tables: "user, userDetail"
package: "demos"
option: |
option go_package = "protobuf/demos";
option java_package = "protobuf/demos";
option php_namespace = "protobuf/demos";
`
CGenPbEntityBriefPath = `directory path for generated files storing`
CGenPbEntityBriefPackage = `package path for all entity proto files`
CGenPbEntityBriefLink = `database configuration, the same as the ORM configuration of GoFrame`
CGenPbEntityBriefTables = `generate models only for given tables, multiple table names separated with ','`
CGenPbEntityBriefPrefix = `add specified prefix for all entity names and entity proto files`
CGenPbEntityBriefRemovePrefix = `remove specified prefix of the table, multiple prefix separated with ','`
CGenPbEntityBriefOption = `extra protobuf options`
CGenPbEntityBriefGroup = `
specifying the configuration group name of database for generated ORM instance,
it's not necessary and the default value is "default"
`
CGenPbEntityBriefNameCase = `
case for message attribute names, default is "Camel":
| Case | Example |
|---------------- |--------------------|
| Camel | AnyKindOfString |
| CamelLower | anyKindOfString | default
| Snake | any_kind_of_string |
| SnakeScreaming | ANY_KIND_OF_STRING |
| SnakeFirstUpper | rgb_code_md5 |
| Kebab | any-kind-of-string |
| KebabScreaming | ANY-KIND-OF-STRING |
`
CGenPbEntityBriefJsonCase = `
case for message json tag, cases are the same as "nameCase", default "CamelLower".
set it to "none" to ignore json tag generating.
`
)
func init() {
gtag.Sets(g.MapStrStr{
`CGenPbEntityConfig`: CGenPbEntityConfig,
`CGenPbEntityBrief`: CGenPbEntityBrief,
`CGenPbEntityEg`: CGenPbEntityEg,
`CGenPbEntityAd`: CGenPbEntityAd,
`CGenPbEntityBriefPath`: CGenPbEntityBriefPath,
`CGenPbEntityBriefPackage`: CGenPbEntityBriefPackage,
`CGenPbEntityBriefLink`: CGenPbEntityBriefLink,
`CGenPbEntityBriefTables`: CGenPbEntityBriefTables,
`CGenPbEntityBriefPrefix`: CGenPbEntityBriefPrefix,
`CGenPbEntityBriefRemovePrefix`: CGenPbEntityBriefRemovePrefix,
`CGenPbEntityBriefGroup`: CGenPbEntityBriefGroup,
`CGenPbEntityBriefNameCase`: CGenPbEntityBriefNameCase,
`CGenPbEntityBriefJsonCase`: CGenPbEntityBriefJsonCase,
`CGenPbEntityBriefOption`: CGenPbEntityBriefOption,
})
}
func (c CGenPbEntity) PbEntity(ctx context.Context, in CGenPbEntityInput) (out *CGenPbEntityOutput, err error) {
var (
config = g.Cfg()
)
if config.Available(ctx) {
v := config.MustGet(ctx, CGenPbEntityConfig)
if v.IsSlice() {
for i := 0; i < len(v.Interfaces()); i++ {
doGenPbEntityForArray(ctx, i, in)
}
} else {
doGenPbEntityForArray(ctx, -1, in)
}
} else {
doGenPbEntityForArray(ctx, -1, in)
}
mlog.Print("done!")
return
}
func doGenPbEntityForArray(ctx context.Context, index int, in CGenPbEntityInput) {
var (
err error
db gdb.DB
)
if index >= 0 {
err = g.Cfg().MustGet(
ctx,
fmt.Sprintf(`%s.%d`, CGenPbEntityConfig, index),
).Scan(&in)
if err != nil {
mlog.Fatalf(`invalid configuration of "%s": %+v`, CGenPbEntityConfig, err)
}
}
if in.Package == "" {
mlog.Debug(`package parameter is empty, trying calculating the package path using go.mod`)
if !gfile.Exists("go.mod") {
mlog.Fatal("go.mod does not exist in current working directory")
}
var (
modName string
goModContent = gfile.GetContents("go.mod")
match, _ = gregex.MatchString(`^module\s+(.+)\s*`, goModContent)
)
if len(match) > 1 {
modName = gstr.Trim(match[1])
in.Package = modName + "/" + defaultPackageSuffix
} else {
mlog.Fatal("module name does not found in go.mod")
}
}
removePrefixArray := gstr.SplitAndTrim(in.RemovePrefix, ",")
// It uses user passed database configuration.
if in.Link != "" {
var (
tempGroup = gtime.TimestampNanoStr()
match, _ = gregex.MatchString(`([a-z]+):(.+)`, in.Link)
)
if len(match) == 3 {
gdb.AddConfigNode(tempGroup, gdb.ConfigNode{
Type: gstr.Trim(match[1]),
Link: gstr.Trim(match[2]),
})
db, _ = gdb.Instance(tempGroup)
}
} else {
db = g.DB()
}
if db == nil {
mlog.Fatal("database initialization failed")
}
tableNames := ([]string)(nil)
if in.Tables != "" {
tableNames = gstr.SplitAndTrim(in.Tables, ",")
} else {
tableNames, err = db.Tables(context.TODO())
if err != nil {
mlog.Fatalf("fetching tables failed: \n %v", err)
}
}
for _, tableName := range tableNames {
newTableName := tableName
for _, v := range removePrefixArray {
newTableName = gstr.TrimLeftStr(newTableName, v, 1)
}
generatePbEntityContentFile(ctx, CGenPbEntityInternalInput{
CGenPbEntityInput: in,
DB: db,
TableName: tableName,
NewTableName: newTableName,
})
}
}
// generatePbEntityContentFile generates the protobuf files for given table.
func generatePbEntityContentFile(ctx context.Context, in CGenPbEntityInternalInput) {
fieldMap, err := in.DB.TableFields(ctx, in.TableName)
if err != nil {
mlog.Fatalf("fetching tables fields failed for table '%s':\n%v", in.TableName, err)
}
// Change the `newTableName` if `Prefix` is given.
newTableName := in.Prefix + in.NewTableName
var (
imports string
tableNameCamelCase = gstr.CaseCamel(newTableName)
tableNameSnakeCase = gstr.CaseSnake(newTableName)
entityMessageDefine = generateEntityMessageDefinition(tableNameCamelCase, fieldMap, in)
fileName = gstr.Trim(tableNameSnakeCase, "-_.")
path = gfile.Join(in.Path, fileName+".proto")
)
if gstr.Contains(entityMessageDefine, "google.protobuf.Timestamp") {
imports = `import "google/protobuf/timestamp.proto";`
}
entityContent := gstr.ReplaceByMap(getTplPbEntityContent(""), g.MapStrStr{
"{Imports}": imports,
"{PackageName}": gfile.Basename(in.Package),
"{GoPackage}": in.Package,
"{OptionContent}": in.Option,
"{EntityMessage}": entityMessageDefine,
})
if err := gfile.PutContents(path, strings.TrimSpace(entityContent)); err != nil {
mlog.Fatalf("writing content to '%s' failed: %v", path, err)
} else {
mlog.Print("generated:", path)
}
}
// generateEntityMessageDefinition generates and returns the message definition for specified table.
func generateEntityMessageDefinition(entityName string, fieldMap map[string]*gdb.TableField, in CGenPbEntityInternalInput) string {
var (
buffer = bytes.NewBuffer(nil)
array = make([][]string, len(fieldMap))
names = sortFieldKeyForPbEntity(fieldMap)
)
for index, name := range names {
array[index] = generateMessageFieldForPbEntity(index+1, fieldMap[name], in)
}
tw := tablewriter.NewWriter(buffer)
tw.SetBorder(false)
tw.SetRowLine(false)
tw.SetAutoWrapText(false)
tw.SetColumnSeparator("")
tw.AppendBulk(array)
tw.Render()
stContent := buffer.String()
// Let's do this hack of table writer for indent!
stContent = gstr.Replace(stContent, " #", "")
buffer.Reset()
buffer.WriteString(fmt.Sprintf("message %s {\n", entityName))
buffer.WriteString(stContent)
buffer.WriteString("}")
return buffer.String()
}
// generateMessageFieldForPbEntity generates and returns the message definition for specified field.
func generateMessageFieldForPbEntity(index int, field *gdb.TableField, in CGenPbEntityInternalInput) []string {
var (
typeName string
comment string
jsonTagStr string
err error
ctx = gctx.GetInitCtx()
)
typeName, err = in.DB.CheckLocalTypeForField(ctx, field.Type, nil)
if err != nil {
panic(err)
}
var typeMapping = map[string]string{
gdb.LocalTypeString: "string",
gdb.LocalTypeDate: "google.protobuf.Timestamp",
gdb.LocalTypeDatetime: "google.protobuf.Timestamp",
gdb.LocalTypeInt: "int32",
gdb.LocalTypeUint: "uint32",
gdb.LocalTypeInt64: "int64",
gdb.LocalTypeUint64: "uint64",
gdb.LocalTypeIntSlice: "repeated int32",
gdb.LocalTypeInt64Slice: "repeated int64",
gdb.LocalTypeUint64Slice: "repeated uint64",
gdb.LocalTypeInt64Bytes: "repeated int64",
gdb.LocalTypeUint64Bytes: "repeated uint64",
gdb.LocalTypeFloat32: "float",
gdb.LocalTypeFloat64: "double",
gdb.LocalTypeBytes: "bytes",
gdb.LocalTypeBool: "bool",
gdb.LocalTypeJson: "string",
gdb.LocalTypeJsonb: "string",
}
typeName = typeMapping[typeName]
if typeName == "" {
typeName = "string"
}
comment = gstr.ReplaceByArray(field.Comment, g.SliceStr{
"\n", " ",
"\r", " ",
})
comment = gstr.Trim(comment)
comment = gstr.Replace(comment, `\n`, " ")
comment, _ = gregex.ReplaceString(`\s{2,}`, ` `, comment)
if jsonTagName := formatCase(field.Name, in.JsonCase); jsonTagName != "" {
// beautiful indent.
if index < 10 {
// 3 spaces
jsonTagStr = " " + jsonTagStr
} else if index < 100 {
// 2 spaces
jsonTagStr = " " + jsonTagStr
} else {
// 1 spaces
jsonTagStr = " " + jsonTagStr
}
}
return []string{
" #" + typeName,
" #" + formatCase(field.Name, in.NameCase),
" #= " + gconv.String(index) + jsonTagStr + ";",
" #" + fmt.Sprintf(`// %s`, comment),
}
}
func getTplPbEntityContent(tplEntityPath string) string {
if tplEntityPath != "" {
return gfile.GetContents(tplEntityPath)
}
return consts.TemplatePbEntityMessageContent
}
// formatCase call gstr.Case* function to convert the s to specified case.
func formatCase(str, caseStr string) string {
switch gstr.ToLower(caseStr) {
case gstr.ToLower("Camel"):
return gstr.CaseCamel(str)
case gstr.ToLower("CamelLower"):
return gstr.CaseCamelLower(str)
case gstr.ToLower("Kebab"):
return gstr.CaseKebab(str)
case gstr.ToLower("KebabScreaming"):
return gstr.CaseKebabScreaming(str)
case gstr.ToLower("Snake"):
return gstr.CaseSnake(str)
case gstr.ToLower("SnakeFirstUpper"):
return gstr.CaseSnakeFirstUpper(str)
case gstr.ToLower("SnakeScreaming"):
return gstr.CaseSnakeScreaming(str)
case "none":
return ""
}
return str
}
func sortFieldKeyForPbEntity(fieldMap map[string]*gdb.TableField) []string {
names := make(map[int]string)
for _, field := range fieldMap {
names[field.Index] = field.Name
}
var (
result = make([]string, len(names))
i = 0
j = 0
)
for {
if len(names) == 0 {
break
}
if val, ok := names[i]; ok {
result[j] = val
j++
delete(names, i)
}
i++
}
return result
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genservice
import (
@@ -5,10 +11,10 @@ import (
"fmt"
"github.com/gogf/gf/v2/container/garray"
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
@@ -44,7 +50,7 @@ destination file name storing automatically generated go files, cases are as fol
`
CGenServiceBriefWatchFile = `used in file watcher, it re-generates all service go files only if given file is under srcFolder`
CGenServiceBriefStPattern = `regular expression matching struct name for generating service. default: ^s([A-Z]\\\\w+)$`
CGenServiceBriefPackages = `produce go files only for given source packages`
CGenServiceBriefPackages = `produce go files only for given source packages(source folders)`
CGenServiceBriefImportPrefix = `custom import prefix to calculate import path for generated importing go file of logic`
CGenServiceBriefClear = `delete all generated go files that are not used any further`
)
@@ -87,21 +93,6 @@ const (
)
func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGenServiceOutput, err error) {
// File lock to avoid multiple processes.
var (
flockFilePath = gfile.Temp("gf.cli.gen.service.lock")
flockContent = gfile.GetContents(flockFilePath)
)
if flockContent != "" {
if gtime.Timestamp()-gconv.Int64(flockContent) < genServiceFileLockSeconds {
// If another "gen service" process is running, it just exits.
mlog.Debug(`another "gen service" process is running, exit`)
return
}
}
defer gfile.Remove(flockFilePath)
_ = gfile.PutContents(flockFilePath, gtime.TimestampStr())
in.SrcFolder = gstr.TrimRight(in.SrcFolder, `\/`)
in.SrcFolder = gstr.Replace(in.SrcFolder, "\\", "/")
in.WatchFile = gstr.TrimRight(in.WatchFile, `\/`)
@@ -109,6 +100,21 @@ func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGe
// Watch file handling.
if in.WatchFile != "" {
// File lock to avoid multiple processes.
var (
flockFilePath = gfile.Temp("gf.cli.gen.service.lock")
flockContent = gfile.GetContents(flockFilePath)
)
if flockContent != "" {
if gtime.Timestamp()-gconv.Int64(flockContent) < genServiceFileLockSeconds {
// If another "gen service" process is running, it just exits.
mlog.Debug(`another "gen service" process is running, exit`)
return
}
}
defer gfile.Remove(flockFilePath)
_ = gfile.PutContents(flockFilePath, gtime.TimestampStr())
// It works only if given WatchFile is in SrcFolder.
var (
watchFileDir = gfile.Dir(in.WatchFile)
@@ -125,13 +131,10 @@ func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGe
mlog.Fatalf(`%+v`, err)
}
mlog.Debug("Chdir:", newWorkingDir)
_ = gfile.Remove(flockFilePath)
var command = fmt.Sprintf(
`%s gen service -packages=%s`,
gfile.SelfName(), gfile.Basename(watchFileDir),
)
err = gproc.ShellRun(ctx, command)
return
in.WatchFile = ""
in.Packages = []string{gfile.Basename(watchFileDir)}
return c.Service(ctx, in)
}
if !gfile.Exists(in.SrcFolder) {
@@ -139,16 +142,7 @@ func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGe
}
if in.ImportPrefix == "" {
if !gfile.Exists("go.mod") {
mlog.Fatal("ImportPrefix is empty and go.mod does not exist in current working directory")
}
var (
goModContent = gfile.GetContents("go.mod")
match, _ = gregex.MatchString(`^module\s+(.+)\s*`, goModContent)
)
if len(match) > 1 {
in.ImportPrefix = fmt.Sprintf(`%s/%s`, gstr.Trim(match[1]), gstr.Replace(in.SrcFolder, `\`, `/`))
}
in.ImportPrefix = utils.GetImportPath(in.SrcFolder)
}
var (
@@ -176,30 +170,118 @@ func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGe
if len(files) == 0 {
continue
}
// Parse single logic package folder.
var (
// StructName => FunctionDefinitions
srcPkgInterfaceMap = make(map[string]*garray.StrArray)
srcImportedPackages = garray.NewSortedStrArray().SetUnique(true)
srcPackageName = gfile.Basename(srcFolderPath)
ok bool
dstFilePath = gfile.Join(in.DstFolder,
srcPkgInterfaceMap = make(map[string]*garray.StrArray)
srcImportedPackages = garray.NewSortedStrArray().SetUnique(true)
importAliasToPathMap = gmap.NewStrStrMap() // for conflict imports check. alias => import path(with `"`)
importPathToAliasMap = gmap.NewStrStrMap() // for conflict imports check. import path(with `"`) => alias
srcPackageName = gfile.Basename(srcFolderPath)
ok bool
dstFilePath = gfile.Join(in.DstFolder,
c.getDstFileNameCase(srcPackageName, in.DstFileNameCase)+".go",
)
srcCodeCommentedMap = make(map[string]string)
)
generatedDstFilePathSet.Add(dstFilePath)
for _, file := range files {
var packageItems []packageItem
fileContent = gfile.GetContents(file)
// Calculate imported packages of source go files.
err = c.calculateImportedPackages(fileContent, srcImportedPackages)
// Calculate code comments in source Go files.
err = c.calculateCodeCommented(in, fileContent, srcCodeCommentedMap)
if err != nil {
return nil, err
}
// remove all comments.
fileContent, err = gregex.ReplaceString(`/[/|\*](.*)`, "", fileContent)
if err != nil {
return nil, err
}
// Calculate imported packages of source go files.
packageItems, err = c.calculateImportedPackages(fileContent)
if err != nil {
return nil, err
}
// try finding the conflicts imports between files.
for _, item := range packageItems {
var alias = item.Alias
if alias == "" {
alias = gfile.Basename(gstr.Trim(item.Path, `"`))
}
// ignore unused import paths, which do not exist in function definitions.
if !gregex.IsMatchString(fmt.Sprintf(`func .+?([^\w])%s(\.\w+).+?{`, alias), fileContent) {
mlog.Debugf(`ignore unused package: %s`, item.RawImport)
continue
}
// find the exist alias with the same import path.
var existAlias = importPathToAliasMap.Get(item.Path)
if existAlias != "" {
fileContent, err = gregex.ReplaceStringFuncMatch(
fmt.Sprintf(`([^\w])%s(\.\w+)`, alias), fileContent,
func(match []string) string {
return match[1] + existAlias + match[2]
},
)
if err != nil {
return nil, err
}
continue
}
// resolve alias conflicts.
var importPath = importAliasToPathMap.Get(alias)
if importPath == "" {
importAliasToPathMap.Set(alias, item.Path)
importPathToAliasMap.Set(item.Path, alias)
srcImportedPackages.Add(item.RawImport)
continue
}
if importPath != item.Path {
// update the conflicted alias for import path with suffix.
// eg:
// v1 -> v10
// v11 -> v110
for aliasIndex := 0; ; aliasIndex++ {
item.Alias = fmt.Sprintf(`%s%d`, alias, aliasIndex)
var existPathForAlias = importAliasToPathMap.Get(item.Alias)
if existPathForAlias != "" {
if existPathForAlias == item.Path {
break
}
continue
}
break
}
importPathToAliasMap.Set(item.Path, item.Alias)
importAliasToPathMap.Set(item.Alias, item.Path)
// reformat the import path with alias.
item.RawImport = fmt.Sprintf(`%s %s`, item.Alias, item.Path)
// update the file content with new alias import.
fileContent, err = gregex.ReplaceStringFuncMatch(
fmt.Sprintf(`([^\w])%s(\.\w+)`, alias), fileContent,
func(match []string) string {
return match[1] + item.Alias + match[2]
},
)
if err != nil {
return nil, err
}
srcImportedPackages.Add(item.RawImport)
}
}
// Calculate functions and interfaces for service generating.
err = c.calculateInterfaceFunctions(in, fileContent, srcPkgInterfaceMap, dstPackageName)
err = c.calculateInterfaceFunctions(in, fileContent, srcPkgInterfaceMap)
if err != nil {
return nil, err
}
}
initImportSrcPackages = append(
initImportSrcPackages,
fmt.Sprintf(`%s/%s`, in.ImportPrefix, srcPackageName),
@@ -212,7 +294,7 @@ func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGe
)
continue
}
// Generating service go file for logic.
// Generating service go file for single logic package.
if ok, err = c.generateServiceFile(generateServiceFilesInput{
CGenServiceInput: in,
SrcStructFunctions: srcPkgInterfaceMap,
@@ -220,6 +302,7 @@ func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGe
SrcPackageName: srcPackageName,
DstPackageName: dstPackageName,
DstFilePath: dstFilePath,
SrcCodeCommentedMap: srcCodeCommentedMap,
}); err != nil {
return
}
@@ -254,24 +337,71 @@ func (c CGenService) Service(ctx context.Context, in CGenServiceInput) (out *CGe
}
// Replace v1 to v2 for GoFrame.
if err = c.replaceGeneratedServiceContentGFV2(in); err != nil {
if err = utils.ReplaceGeneratedContentGFV2(in.DstFolder); err != nil {
return nil, err
}
mlog.Printf(`gofmt go files in "%s"`, in.DstFolder)
utils.GoFmt(in.DstFolder)
}
// auto update main.go.
if err = c.checkAndUpdateMain(in.SrcFolder); err != nil {
return nil, err
}
mlog.Print(`done!`)
return
}
func (c CGenService) replaceGeneratedServiceContentGFV2(in CGenServiceInput) (err error) {
return gfile.ReplaceDirFunc(func(path, content string) string {
if gstr.Contains(content, `"github.com/gogf/gf`) && !gstr.Contains(content, `"github.com/gogf/gf/v2`) {
content = gstr.Replace(content, `"github.com/gogf/gf"`, `"github.com/gogf/gf/v2"`)
content = gstr.Replace(content, `"github.com/gogf/gf/`, `"github.com/gogf/gf/v2/`)
return content
func (c CGenService) checkAndUpdateMain(srcFolder string) (err error) {
var (
logicPackageName = gstr.ToLower(gfile.Basename(srcFolder))
logicFilePath = gfile.Join(srcFolder, logicPackageName+".go")
importPath = utils.GetImportPath(logicFilePath)
importStr = fmt.Sprintf(`_ "%s"`, importPath)
mainFilePath = gfile.Join(gfile.Dir(gfile.Dir(gfile.Dir(logicFilePath))), "main.go")
mainFileContent = gfile.GetContents(mainFilePath)
)
// No main content found.
if mainFileContent == "" {
return nil
}
if gstr.Contains(mainFileContent, importStr) {
return nil
}
match, err := gregex.MatchString(`import \(([\s\S]+?)\)`, mainFileContent)
if err != nil {
return err
}
// No match.
if len(match) < 2 {
return nil
}
lines := garray.NewStrArrayFrom(gstr.Split(match[1], "\n"))
for i, line := range lines.Slice() {
line = gstr.Trim(line)
if len(line) == 0 {
continue
}
return content
}, in.DstFolder, "*.go", false)
if line[0] == '_' {
continue
}
// Insert the logic import into imports.
if err = lines.InsertBefore(i, fmt.Sprintf("\t%s\n\n", importStr)); err != nil {
return err
}
break
}
mainFileContent, err = gregex.ReplaceString(
`import \(([\s\S]+?)\)`,
fmt.Sprintf(`import (%s)`, lines.Join("\n")),
mainFileContent,
)
if err != nil {
return err
}
mlog.Print(`update main.go`)
err = gfile.PutContents(mainFilePath, mainFileContent)
utils.GoFmt(mainFilePath)
return
}

View File

@@ -1,6 +1,13 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genservice
import (
"fmt"
"go/parser"
"go/token"
@@ -9,29 +16,90 @@ import (
"github.com/gogf/gf/v2/text/gstr"
)
func (c CGenService) calculateImportedPackages(fileContent string, srcImportedPackages *garray.SortedStrArray) (err error) {
type packageItem struct {
Alias string
Path string
RawImport string
}
func (c CGenService) calculateImportedPackages(fileContent string) (packages []packageItem, err error) {
f, err := parser.ParseFile(token.NewFileSet(), "", fileContent, parser.ImportsOnly)
if err != nil {
return err
return nil, err
}
packages = make([]packageItem, 0)
for _, s := range f.Imports {
if s.Path != nil {
if s.Name != nil {
// If it has alias, and it is not `_`.
if pkgAlias := s.Name.String(); pkgAlias != "_" {
srcImportedPackages.Add(pkgAlias + " " + s.Path.Value)
packages = append(packages, packageItem{
Alias: pkgAlias,
Path: s.Path.Value,
RawImport: pkgAlias + " " + s.Path.Value,
})
}
} else {
// no alias
srcImportedPackages.Add(s.Path.Value)
packages = append(packages, packageItem{
Alias: "",
Path: s.Path.Value,
RawImport: s.Path.Value,
})
}
}
}
return packages, nil
}
func (c CGenService) calculateCodeCommented(in CGenServiceInput, fileContent string, srcCodeCommentedMap map[string]string) error {
matches, err := gregex.MatchAllString(`((((//.*)|(/\*[\s\S]*?\*/))\s)+)func \((.+?)\) ([\s\S]+?) {`, fileContent)
if err != nil {
return err
}
for _, match := range matches {
var (
structName string
structMatch []string
funcReceiver = gstr.Trim(match[1+5])
receiverArray = gstr.SplitAndTrim(funcReceiver, " ")
functionHead = gstr.Trim(gstr.Replace(match[2+5], "\n", ""))
commentedInfo = ""
)
if len(receiverArray) > 1 {
structName = receiverArray[1]
} else if len(receiverArray) == 1 {
structName = receiverArray[0]
}
structName = gstr.Trim(structName, "*")
// Case of:
// Xxx(\n ctx context.Context, req *v1.XxxReq,\n) -> Xxx(ctx context.Context, req *v1.XxxReq)
functionHead = gstr.Replace(functionHead, `,)`, `)`)
functionHead, _ = gregex.ReplaceString(`\(\s+`, `(`, functionHead)
functionHead, _ = gregex.ReplaceString(`\s{2,}`, ` `, functionHead)
if !gstr.IsLetterUpper(functionHead[0]) {
continue
}
// Match and pick the struct name from receiver.
if structMatch, err = gregex.MatchString(in.StPattern, structName); err != nil {
return err
}
if len(structMatch) < 1 {
continue
}
structName = gstr.CaseCamel(structMatch[1])
commentedInfo = match[1]
if len(commentedInfo) > 0 {
srcCodeCommentedMap[fmt.Sprintf("%s-%s", structName, functionHead)] = commentedInfo
}
}
return nil
}
func (c CGenService) calculateInterfaceFunctions(
in CGenServiceInput, fileContent string, srcPkgInterfaceMap map[string]*garray.StrArray, dstPackageName string,
in CGenServiceInput, fileContent string, srcPkgInterfaceMap map[string]*garray.StrArray,
) (err error) {
var (
ok bool
@@ -53,7 +121,7 @@ func (c CGenService) calculateInterfaceFunctions(
)
if len(receiverArray) > 1 {
structName = receiverArray[1]
} else {
} else if len(receiverArray) == 1 {
structName = receiverArray[0]
}
structName = gstr.Trim(structName, "*")

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package genservice
import (
@@ -21,6 +27,7 @@ type generateServiceFilesInput struct {
SrcImportedPackages []string
SrcPackageName string
DstPackageName string
SrcCodeCommentedMap map[string]string
}
func (c CGenService) generateServiceFile(in generateServiceFilesInput) (ok bool, err error) {
@@ -41,6 +48,13 @@ func (c CGenService) generateServiceFile(in generateServiceFilesInput) (ok bool,
generatedContent += "\n"
for structName, funcArray := range in.SrcStructFunctions {
allFuncArray.Append(funcArray.Slice()...)
// Add comments to a method.
for index, funcName := range funcArray.Slice() {
if commentedInfo, exist := in.SrcCodeCommentedMap[fmt.Sprintf("%s-%s", structName, funcName)]; exist {
funcName = commentedInfo + funcName
_ = funcArray.Set(index, funcName)
}
}
generatedContent += gstr.Trim(gstr.ReplaceByMap(consts.TemplateGenServiceContentInterface, g.MapStrStr{
"{InterfaceName}": "I" + structName,
"{FuncDefinition}": funcArray.Join("\n\t"),
@@ -56,7 +70,7 @@ func (c CGenService) generateServiceFile(in generateServiceFilesInput) (ok bool,
generatingInterfaceCheck string
)
// Variable definitions.
for structName, _ := range in.SrcStructFunctions {
for structName := range in.SrcStructFunctions {
generatingInterfaceCheck = fmt.Sprintf(`[^\w\d]+%s.I%s[^\w\d]`, in.DstPackageName, structName)
if gregex.IsMatchString(generatingInterfaceCheck, generatedContent) {
continue
@@ -75,7 +89,7 @@ func (c CGenService) generateServiceFile(in generateServiceFilesInput) (ok bool,
generatedContent += "\n"
}
// Variable register function definitions.
for structName, _ := range in.SrcStructFunctions {
for structName := range in.SrcStructFunctions {
generatingInterfaceCheck = fmt.Sprintf(`[^\w\d]+%s.I%s[^\w\d]`, in.DstPackageName, structName)
if gregex.IsMatchString(generatingInterfaceCheck, generatedContent) {
continue
@@ -114,6 +128,7 @@ func (c CGenService) generateServiceFile(in generateServiceFilesInput) (ok bool,
// isToGenerateServiceGoFile checks and returns whether the service content dirty.
func (c CGenService) isToGenerateServiceGoFile(dstPackageName, filePath string, funcArray *garray.StrArray) bool {
var (
err error
fileContent = gfile.GetContents(filePath)
generatedFuncArray = garray.NewSortedStrArrayFrom(funcArray.Slice())
contentFuncArray = garray.NewSortedStrArray()
@@ -121,6 +136,12 @@ func (c CGenService) isToGenerateServiceGoFile(dstPackageName, filePath string,
if fileContent == "" {
return true
}
// remove all comments.
fileContent, err = gregex.ReplaceString(`/[/|\*](.*)`, "", fileContent)
if err != nil {
panic(err)
return false
}
matches, _ := gregex.MatchAllString(`\s+interface\s+{([\s\S]+?)}`, fileContent)
for _, match := range matches {
contentFuncArray.Append(gstr.SplitAndTrim(match[1], "\n")...)
@@ -145,29 +166,30 @@ func (c CGenService) isToGenerateServiceGoFile(dstPackageName, filePath string,
return false
}
// generateInitializationFile generates `logic.go`.
func (c CGenService) generateInitializationFile(in CGenServiceInput, importSrcPackages []string) (err error) {
var (
srcPackageName = gstr.ToLower(gfile.Basename(in.SrcFolder))
srcFilePath = gfile.Join(in.SrcFolder, srcPackageName+".go")
srcImports string
logicPackageName = gstr.ToLower(gfile.Basename(in.SrcFolder))
logicFilePath = gfile.Join(in.SrcFolder, logicPackageName+".go")
logicImports string
generatedContent string
)
if !utils.IsFileDoNotEdit(srcFilePath) {
mlog.Debugf(`ignore file as it is manually maintained: %s`, srcFilePath)
if !utils.IsFileDoNotEdit(logicFilePath) {
mlog.Debugf(`ignore file as it is manually maintained: %s`, logicFilePath)
return nil
}
for _, importSrcPackage := range importSrcPackages {
srcImports += fmt.Sprintf(`%s_ "%s"%s`, "\t", importSrcPackage, "\n")
logicImports += fmt.Sprintf(`%s_ "%s"%s`, "\t", importSrcPackage, "\n")
}
generatedContent = gstr.ReplaceByMap(consts.TemplateGenServiceLogicContent, g.MapStrStr{
"{PackageName}": srcPackageName,
"{Imports}": srcImports,
"{PackageName}": logicPackageName,
"{Imports}": logicImports,
})
mlog.Printf(`generating init go file: %s`, srcFilePath)
if err = gfile.PutContents(srcFilePath, generatedContent); err != nil {
mlog.Printf(`generating init go file: %s`, logicFilePath)
if err = gfile.PutContents(logicFilePath, generatedContent); err != nil {
return err
}
utils.GoFmt(srcFilePath)
utils.GoFmt(logicFilePath)
return nil
}

View File

@@ -0,0 +1,30 @@
package testdata
import (
"fmt"
"testing"
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/ghttp"
"github.com/gogf/gf/v2/test/gtest"
"github.com/gogf/gf/v2/util/guid"
)
func Test_Router_Hook_Multi(t *testing.T) {
s := g.Server(guid.S())
s.BindHandler("/multi-hook", func(r *ghttp.Request) {
r.Response.Write("show")
})
s.BindHookHandlerByMap("/multi-hook", map[string]ghttp.HandlerFunc{
ghttp.HookBeforeServe: func(r *ghttp.Request) {
r.Response.Write("1")
},
})
s.BindHookHandlerByMap("/multi-hook/{id}", map[string]ghttp.HandlerFunc{
ghttp.HookBeforeServe: func(r *ghttp.Request) {
r.Response.Write("2")
},
})
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const (

View File

@@ -0,0 +1,65 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplateGenCtrlControllerEmpty = `
package {Module}
`
const TemplateGenCtrlControllerNewEmpty = `
// =================================================================================
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// =================================================================================
package {Module}
import (
{ImportPath}
)
`
const TemplateGenCtrlControllerNewFunc = `
type {CtrlName} struct{}
func {NewFuncName}() {InterfaceName} {
return &{CtrlName}{}
}
`
const TemplateGenCtrlControllerMethodFunc = `
package {Module}
import (
"context"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"{ImportPath}"
)
func (c *{CtrlName}) {MethodName}(ctx context.Context, req *{Version}.{MethodName}Req) (res *{Version}.{MethodName}Res, err error) {
return nil, gerror.NewCode(gcode.CodeNotImplemented)
}
`
const TemplateGenCtrlApiInterface = `
// =================================================================================
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// =================================================================================
package {Module}
import (
{ImportPaths}
)
{Interfaces}
`

View File

@@ -0,0 +1,94 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplateGenCtrlSdkPkgNew = `
// =================================================================================
// This is auto-generated by GoFrame CLI tool only once. Fill this file as you wish.
// =================================================================================
package {PkgName}
import (
"fmt"
"github.com/gogf/gf/contrib/sdk/httpclient/v2"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/text/gstr"
)
type implementer struct {
config httpclient.Config
}
func New(config httpclient.Config) iClient {
if !gstr.HasPrefix(config.URL, "http") {
config.URL = fmt.Sprintf("http://%s", config.URL)
}
if config.Logger == nil {
config.Logger = g.Log()
}
return &implementer{
config: config,
}
}
`
const TemplateGenCtrlSdkIClient = `
// =================================================================================
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// =================================================================================
package {PkgName}
import (
)
type iClient interface {
}
`
const TemplateGenCtrlSdkImplementer = `
// =================================================================================
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// =================================================================================
package {PkgName}
import (
"context"
"github.com/gogf/gf/contrib/sdk/httpclient/v2"
"github.com/gogf/gf/v2/text/gstr"
{ImportPaths}
)
type implementer{ImplementerName} struct {
*httpclient.Client
}
`
const TemplateGenCtrlSdkImplementerNew = `
func (i *implementer) {ImplementerName}() {Module}.I{ImplementerName} {
var (
client = httpclient.New(i.config)
prefix = gstr.TrimRight(i.config.URL, "/") + "{VersionPrefix}"
)
client.Client = client.Prefix(prefix)
return &implementer{ImplementerName}{client}
}
`
const TemplateGenCtrlSdkImplementerFunc = `
func (i *implementer{ImplementerName}) {MethodName}(ctx context.Context, req *{Version}.{MethodName}Req) (res *{Version}.{MethodName}Res, err error) {
err = i.Request(ctx, req, &res)
return
}
`

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplateGenDaoIndexContent = `
@@ -33,7 +39,7 @@ var (
const TemplateGenDaoInternalContent = `
// ==========================================================================
// Code generated by GoFrame CLI tool. DO NOT EDIT. {TplCreatedAtDatetimeStr}
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. {TplCreatedAtDatetimeStr}
// ==========================================================================
package internal

View File

@@ -1,8 +1,14 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplateGenDaoDoContent = `
// =================================================================================
// Code generated by GoFrame CLI tool. DO NOT EDIT. {TplCreatedAtDatetimeStr}
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. {TplCreatedAtDatetimeStr}
// =================================================================================
package do

View File

@@ -1,8 +1,14 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplateGenDaoEntityContent = `
// =================================================================================
// Code generated by GoFrame CLI tool. DO NOT EDIT. {TplCreatedAtDatetimeStr}
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT. {TplCreatedAtDatetimeStr}
// =================================================================================
package entity

View File

@@ -8,7 +8,7 @@ package consts
const TemplateGenEnums = `
// ================================================================================
// Code generated by GoFrame CLI tool. DO NOT EDIT.
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// ================================================================================
package {PackageName}

View File

@@ -1,17 +1,23 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplatePbEntityMessageContent = `
// ==========================================================================
// Code generated by GoFrame CLI tool. DO NOT EDIT.
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// ==========================================================================
syntax = "proto3";
package {PackageName};
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
option go_package = "{GoPackage}";
{OptionContent}
{Imports}
{EntityMessage}
`

View File

@@ -1,8 +1,14 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplateGenServiceContentHead = `
// ================================================================================
// Code generated by GoFrame CLI tool. DO NOT EDIT.
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// You can delete these comments if you wish manually maintain this interface file.
// ================================================================================

View File

@@ -1,8 +1,14 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package consts
const TemplateGenServiceLogicContent = `
// ==========================================================================
// Code generated by GoFrame CLI tool. DO NOT EDIT.
// Code generated and maintained by GoFrame CLI tool. DO NOT EDIT.
// ==========================================================================
package {PackageName}

View File

@@ -1 +1,7 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package packed

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package service
import (
@@ -29,6 +35,7 @@ type serviceInstallAvailablePath struct {
filePath string
writable bool
installed bool
IsSelf bool
}
func (s serviceInstall) Run(ctx context.Context) (err error) {
@@ -104,6 +111,8 @@ func (s serviceInstall) Run(ctx context.Context) (err error) {
)
if input != "" {
inputID = gconv.Int(input)
} else {
break
}
// Check if out of range.
if inputID >= len(paths) || inputID < 0 {
@@ -131,17 +140,17 @@ func (s serviceInstall) Run(ctx context.Context) (err error) {
}
// IsInstalled checks and returns whether the binary is installed.
func (s serviceInstall) IsInstalled() bool {
func (s serviceInstall) IsInstalled() (*serviceInstallAvailablePath, bool) {
paths := s.getAvailablePaths()
for _, aPath := range paths {
if aPath.installed {
return true
return &aPath, true
}
}
return false
return nil, false
}
// getGoPathBinFilePath retrieves ad returns the GOPATH/bin path for binary.
// getGoPathBin retrieves ad returns the GOPATH/bin path for binary.
func (s serviceInstall) getGoPathBin() string {
if goPath := genv.Get(`GOPATH`).String(); goPath != "" {
return gfile.Join(goPath, "bin")
@@ -214,6 +223,7 @@ func (s serviceInstall) checkAndAppendToAvailablePath(folderPaths []serviceInsta
filePath = gfile.Join(dirPath, binaryFileName)
writable = gfile.IsWritable(dirPath)
installed = gfile.Exists(filePath)
self = gfile.SelfPath() == filePath
)
if !writable && !installed {
return folderPaths
@@ -225,5 +235,6 @@ func (s serviceInstall) checkAndAppendToAvailablePath(folderPaths []serviceInsta
writable: writable,
filePath: filePath,
installed: installed,
IsSelf: self,
})
}

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package allyes
import (

View File

@@ -1,3 +1,9 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package mlog
import (

View File

@@ -1,10 +1,20 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import (
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/text/gstr"
"context"
"fmt"
"golang.org/x/tools/imports"
"github.com/gogf/gf/v2/os/gfile"
"github.com/gogf/gf/v2/os/gproc"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"hotgo/internal/library/hggen/internal/consts"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
@@ -36,6 +46,13 @@ func GoFmt(path string) {
}
}
// GoModTidy executes `go mod tidy` at specified directory `dirPath`.
func GoModTidy(ctx context.Context, dirPath string) error {
command := fmt.Sprintf(`cd %s && go mod tidy`, dirPath)
err := gproc.ShellRun(ctx, command)
return err
}
// IsFileDoNotEdit checks and returns whether file contains `do not edit` key.
func IsFileDoNotEdit(filePath string) bool {
if !gfile.Exists(filePath) {
@@ -43,3 +60,78 @@ func IsFileDoNotEdit(filePath string) bool {
}
return gstr.Contains(gfile.GetContents(filePath), consts.DoNotEditKey)
}
// ReplaceGeneratedContentGFV2 replaces generated go content from goframe v1 to v2.
func ReplaceGeneratedContentGFV2(folderPath string) (err error) {
return gfile.ReplaceDirFunc(func(path, content string) string {
if gstr.Contains(content, `"github.com/gogf/gf`) && !gstr.Contains(content, `"github.com/gogf/gf/v2`) {
content = gstr.Replace(content, `"github.com/gogf/gf"`, `"github.com/gogf/gf/v2"`)
content = gstr.Replace(content, `"github.com/gogf/gf/`, `"github.com/gogf/gf/v2/`)
content = gstr.Replace(content, `"github.com/gogf/gf/v2/contrib/`, `"github.com/gogf/gf/contrib/`)
return content
}
return content
}, folderPath, "*.go", true)
}
// GetImportPath calculates and returns the golang import path for given `filePath`.
// Note that it needs a `go.mod` in current working directory or parent directories to detect the path.
func GetImportPath(filePath string) string {
// If `filePath` does not exist, create it firstly to find the import path.
var realPath = gfile.RealPath(filePath)
if realPath == "" {
_ = gfile.Mkdir(filePath)
realPath = gfile.RealPath(filePath)
}
var (
newDir = gfile.Dir(realPath)
oldDir string
suffix string
goModName = "go.mod"
goModPath string
importPath string
)
if gfile.IsDir(filePath) {
suffix = gfile.Basename(filePath)
}
for {
goModPath = gfile.Join(newDir, goModName)
if gfile.Exists(goModPath) {
match, _ := gregex.MatchString(`^module\s+(.+)\s*`, gfile.GetContents(goModPath))
importPath = gstr.Trim(match[1]) + "/" + suffix
importPath = gstr.Replace(importPath, `\`, `/`)
importPath = gstr.TrimRight(importPath, `/`)
return importPath
}
oldDir = newDir
newDir = gfile.Dir(oldDir)
if newDir == oldDir {
return ""
}
suffix = gfile.Basename(oldDir) + "/" + suffix
}
}
// GetModPath retrieves and returns the file path of go.mod for current project.
func GetModPath() string {
var (
oldDir = gfile.Pwd()
newDir = gfile.Dir(oldDir)
goModName = "go.mod"
goModPath string
)
for {
goModPath = gfile.Join(newDir, goModName)
if gfile.Exists(goModPath) {
return goModPath
}
oldDir = newDir
newDir = gfile.Dir(oldDir)
if newDir == oldDir {
break
}
}
return ""
}

View File

@@ -0,0 +1,105 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils
import (
"fmt"
"io"
"net/http"
"os"
"strconv"
"time"
"github.com/gogf/gf/v2/errors/gerror"
"hotgo/internal/library/hggen/internal/utility/mlog"
)
// HTTPDownloadFileWithPercent downloads target url file to local path with percent process printing.
func HTTPDownloadFileWithPercent(url string, localSaveFilePath string) error {
start := time.Now()
out, err := os.Create(localSaveFilePath)
if err != nil {
return gerror.Wrapf(err, `download "%s" to "%s" failed`, url, localSaveFilePath)
}
defer out.Close()
headResp, err := http.Head(url)
if err != nil {
return gerror.Wrapf(err, `download "%s" to "%s" failed`, url, localSaveFilePath)
}
defer headResp.Body.Close()
size, err := strconv.Atoi(headResp.Header.Get("Content-Length"))
if err != nil {
return gerror.Wrap(err, "retrieve Content-Length failed")
}
doneCh := make(chan int64)
go doPrintDownloadPercent(doneCh, localSaveFilePath, int64(size))
resp, err := http.Get(url)
if err != nil {
return gerror.Wrapf(err, `download "%s" to "%s" failed`, url, localSaveFilePath)
}
defer resp.Body.Close()
wroteBytesCount, err := io.Copy(out, resp.Body)
if err != nil {
return gerror.Wrapf(err, `download "%s" to "%s" failed`, url, localSaveFilePath)
}
doneCh <- wroteBytesCount
elapsed := time.Since(start)
if elapsed > time.Minute {
mlog.Printf(`download completed in %.0fm`, float64(elapsed)/float64(time.Minute))
} else {
mlog.Printf(`download completed in %.0fs`, elapsed.Seconds())
}
return nil
}
func doPrintDownloadPercent(doneCh chan int64, localSaveFilePath string, total int64) {
var (
stop = false
lastPercentFmt string
)
for {
select {
case <-doneCh:
stop = true
default:
file, err := os.Open(localSaveFilePath)
if err != nil {
mlog.Fatal(err)
}
fi, err := file.Stat()
if err != nil {
mlog.Fatal(err)
}
size := fi.Size()
if size == 0 {
size = 1
}
var (
percent = float64(size) / float64(total) * 100
percentFmt = fmt.Sprintf(`%.0f`, percent) + "%"
)
if lastPercentFmt != percentFmt {
lastPercentFmt = percentFmt
mlog.Print(percentFmt)
}
}
if stop {
break
}
time.Sleep(time.Second)
}
}

View File

@@ -0,0 +1,22 @@
// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package utils_test
import (
"fmt"
"testing"
"github.com/gogf/gf/v2/test/gtest"
"hotgo/internal/library/hggen/internal/utility/utils"
)
func Test_GetModPath(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
goModPath := utils.GetModPath()
fmt.Println(goModPath)
})
}

View File

@@ -3,7 +3,6 @@
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
//
package views
import (
@@ -20,7 +19,7 @@ import (
)
// DoTableColumns 获取指定表生成字段列表
func DoTableColumns(ctx context.Context, in sysin.GenCodesColumnListInp, config gendao.CGenDaoInput) (fields []*sysin.GenCodesColumnListModel, err error) {
func DoTableColumns(ctx context.Context, in *sysin.GenCodesColumnListInp, config gendao.CGenDaoInput) (fields []*sysin.GenCodesColumnListModel, err error) {
var (
sql = "select ORDINAL_POSITION as `id`, COLUMN_NAME as `name`, COLUMN_COMMENT as `dc`, DATA_TYPE as `dataType`, COLUMN_TYPE as `sqlType`, CHARACTER_MAXIMUM_LENGTH as `length`, IS_NULLABLE as `isAllowNull`, COLUMN_DEFAULT as `defaultValue`, COLUMN_KEY as `index`, EXTRA as `extra` from information_schema.COLUMNS where TABLE_SCHEMA = '%s' and TABLE_NAME = '%s' ORDER BY `id` ASC"
conf = g.DB(in.Name).GetConfig()
@@ -42,7 +41,6 @@ func DoTableColumns(ctx context.Context, in sysin.GenCodesColumnListInp, config
CustomAttributes(ctx, field, config)
}
}
return
}
@@ -65,10 +63,7 @@ func CustomAttributes(ctx context.Context, field *sysin.GenCodesColumnListModel,
// GenGotype 生成字段的go类型
func GenGotype(ctx context.Context, field *sysin.GenCodesColumnListModel, in gendao.CGenDaoInput) (goName, typeName, tsName, tsType string) {
var (
err error
)
var err error
tsName = getJsonTagFromCase(field.Name, in.JsonCase)
goName = gstr.CaseCamel(field.Name)
@@ -100,7 +95,6 @@ func GenGotype(ctx context.Context, field *sysin.GenCodesColumnListModel, in gen
}
tsType = ShiftMap[typeName]
return
}

View File

@@ -317,5 +317,10 @@ func setDefaultValue(field *sysin.GenCodesColumnListModel) {
value = consts.ConvType(field.DefaultValue, field.GoType)
}
// 时间类型不做默认值处理
if field.GoType == GoTypeGTime {
value = ""
}
field.DefaultValue = value
}

View File

@@ -7,7 +7,6 @@ package views
import (
"context"
"fmt"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gfile"
@@ -67,14 +66,14 @@ type CurdOptions struct {
HeadOps []string `json:"headOps"`
Join []*CurdOptionsJoin `json:"join"`
Menu *CurdOptionsMenu `json:"menu"`
Step *CurdStep // 转换后的流程控制条件
dictMap g.Map // 字典选项 -> 字段映射关系
TemplateGroup string `json:"templateGroup"`
ApiPrefix string `json:"apiPrefix"`
Step *CurdStep // 转换后的流程控制条件
dictMap g.Map // 字典选项 -> 字段映射关系
}
type CurdPreviewInput struct {
In sysin.GenCodesPreviewInp // 提交参数
In *sysin.GenCodesPreviewInp // 提交参数
DaoConfig gendao.CGenDaoInput // 生成dao配置
Config *model.GenerateConfig // 生成配置
view *gview.View // 视图模板
@@ -102,7 +101,7 @@ func (l *gCurd) initInput(ctx context.Context, in *CurdPreviewInput) (err error)
}
if len(in.masterFields) == 0 {
if in.masterFields, err = DoTableColumns(ctx, sysin.GenCodesColumnListInp{Name: in.In.DbName, Table: in.In.TableName}, in.DaoConfig); err != nil {
if in.masterFields, err = DoTableColumns(ctx, &sysin.GenCodesColumnListInp{Name: in.In.DbName, Table: in.In.TableName}, in.DaoConfig); err != nil {
return
}
}
@@ -236,7 +235,7 @@ func (l *gCurd) DoBuild(ctx context.Context, in *CurdBuildInput) (err error) {
for name, f := range in.BeforeEvent {
if gstr.InArray(in.PreviewIn.options.AutoOps, name) {
if err = f(ctx); err != nil {
return fmt.Errorf("in doBuild operation beforeEvent to '%s' failed::%v", name, err)
return gerror.Newf("in doBuild operation beforeEvent to '%s' failed:%v", name, err)
}
}
}
@@ -254,7 +253,7 @@ func (l *gCurd) DoBuild(ctx context.Context, in *CurdBuildInput) (err error) {
}
if err = gfile.PutContents(vi.Path, strings.TrimSpace(vi.Content)); err != nil {
return fmt.Errorf("writing content to '%s' failed: %v", vi.Path, err)
return gerror.Newf("writing content to '%s' failed: %v", vi.Path, err)
}
if gstr.Str(vi.Path, `.`) == ".sql" && needExecSql {
@@ -266,7 +265,6 @@ func (l *gCurd) DoBuild(ctx context.Context, in *CurdBuildInput) (err error) {
if gstr.Str(vi.Path, `.`) == ".go" {
utils.GoFmt(vi.Path)
}
}
// 后置操作
@@ -274,7 +272,7 @@ func (l *gCurd) DoBuild(ctx context.Context, in *CurdBuildInput) (err error) {
for name, f := range in.AfterEvent {
if gstr.InArray(in.PreviewIn.options.AutoOps, name) {
if err = f(ctx); err != nil {
return fmt.Errorf("in doBuild operation afterEvent to '%s' failed::%v", name, err)
return gerror.Newf("in doBuild operation afterEvent to '%s' failed:%v", name, err)
}
}
}
@@ -515,12 +513,12 @@ func (l *gCurd) generateWebModelContent(ctx context.Context, in *CurdPreviewInpu
tplData, err := l.webModelTplData(ctx, in)
if err != nil {
return err
return
}
genFile.Content, err = in.view.Parse(ctx, name+".template", tplData)
if err != nil {
return err
return
}
genFile.Path = file.MergeAbs(in.Config.Application.Crud.Templates[in.In.GenTemplate].WebViewsPath, gstr.LcFirst(in.In.VarName), "model.ts")
@@ -564,7 +562,6 @@ func (l *gCurd) generateWebIndexContent(ctx context.Context, in *CurdPreviewInpu
genFile.Meth = consts.GenCodesBuildMethCover
}
in.content.Views[name] = genFile
return
}

View File

@@ -164,7 +164,6 @@ func (l *gCurd) generateStructFieldDefinition(field *sysin.GenCodesColumnListMod
default:
panic("inputType is invalid")
}
return result
}

View File

@@ -48,7 +48,7 @@ func (l *gCurd) generateWebEditFormItem(ctx context.Context, in *CurdPreviewInpu
component = fmt.Sprintf("<n-form-item label=\"%s\" path=\"%s\">\n <n-input type=\"textarea\" placeholder=\"%s\" v-model:value=\"params.%s\" />\n </n-form-item>", field.Dc, field.TsName, field.Dc, field.TsName)
case FormModeInputEditor:
component = fmt.Sprintf("<n-form-item label=\"%s\" path=\"%s\">\n <Editor style=\"height: 450px\" v-model:value=\"params.%s\" />\n </n-form-item>", field.Dc, field.TsName, field.TsName)
component = fmt.Sprintf("<n-form-item label=\"%s\" path=\"%s\">\n <Editor style=\"height: 450px\" id=\"%s\" v-model:value=\"params.%s\" />\n </n-form-item>", field.Dc, field.TsName, field.TsName, field.TsName)
case FormModeInputDynamic:
component = fmt.Sprintf("<n-form-item label=\"%s\" path=\"%s\">\n <n-dynamic-input\n v-model:value=\"params.%s\"\n preset=\"pair\"\n key-placeholder=\"键名\"\n value-placeholder=\"键值\"\n />\n </n-form-item>", field.Dc, field.TsName, field.TsName)

View File

@@ -83,6 +83,5 @@ func (l *gCurd) webIndexTplData(ctx context.Context, in *CurdPreviewInput) (g.Ma
}
}
data["isSearchForm"] = isSearchForm
return data, nil
}

View File

@@ -25,7 +25,6 @@ func (l *gCurd) parseServFunName(templateGroup, varName string) string {
if gstr.HasPrefix(varName, templateGroup) && varName != templateGroup {
return varName
}
return templateGroup + varName
}
@@ -94,7 +93,6 @@ func ImportSql(ctx context.Context, path string) error {
return err
}
}
return nil
}
@@ -144,7 +142,6 @@ func checkCurdPath(temp *model.GenerateAppCrudTemplate, addonName string) (err e
if !gfile.Exists(temp.WebViewsPath) {
return gerror.Newf(tip, "WebViewsPath", temp.WebViewsPath)
}
return
}

View File

@@ -29,6 +29,10 @@ type Join struct {
fields map[string]*gdb.TableField // 表字段列表
}
func LeftJoin(m *gdb.Model, masterTable, masterField, joinTable, alias, onField string) *gdb.Model {
return m.LeftJoin(GenJoinOnRelation(masterTable, masterField, joinTable, alias, onField)...)
}
// GenJoinOnRelation 生成关联表关联条件
func GenJoinOnRelation(masterTable, masterField, joinTable, alias, onField string) []string {
relation := fmt.Sprintf("`%s`.`%s` = `%s`.`%s`", alias, onField, masterTable, masterField)

View File

@@ -7,19 +7,33 @@ package tcp
import (
"context"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/os/gtime"
"hotgo/utility/simple"
"reflect"
"sync"
"time"
)
// Client tcp客户端
type Client struct {
ctx context.Context // 上下文
conn *Conn // 连接对象
config *ClientConfig // 配置
msgParser *MsgParser // 消息处理器
logger *glog.Logger // 日志处理器
isLogin *gtype.Bool // 是否已登录
taskGo *grpool.Pool // 任务协程池
closeFlag *gtype.Bool // 关闭标签,关闭以后可以重连
stopFlag *gtype.Bool // 停止标签,停止以后不能重连
wg sync.WaitGroup // 流程控制
mutex sync.Mutex // 状态锁
}
// ClientConfig 客户端配置
type ClientConfig struct {
Addr string // 连接地址
@@ -33,86 +47,56 @@ type ClientConfig struct {
CloseEvent CallbackEvent // 连接关闭事件
}
// Client 客户端
type Client struct {
Ctx context.Context // 上下文
Logger *glog.Logger // 日志处理器
IsLogin bool // 是否已登录
addr string // 连接地址
auth *AuthMeta // 认证元数据
rpc *Rpc // rpc协议支持
timeout time.Duration // 连接超时时间
connectInterval time.Duration // 重连时间间隔
maxConnectCount uint // 最大重连次数0不限次数
connectCount uint // 已重连次数
autoReconnect bool // 是否开启自动重连
loginEvent CallbackEvent // 登录成功事件
closeEvent CallbackEvent // 连接关闭事件
sync.Mutex // 状态锁
heartbeat int64 // 心跳
msgGo *grpool.Pool // 消息处理协程池
routers map[string]RouterHandler // 已注册的路由
conn *gtcp.Conn // 连接对象
wg sync.WaitGroup // 状态控制
closeFlag bool // 关闭标签,关闭以后可以重连
stopFlag bool // 停止标签,停止以后不能重连
}
// CallbackEvent 回调事件
type CallbackEvent func()
// NewClient 初始化一个tcp客户端
func NewClient(config *ClientConfig) (client *Client, err error) {
func NewClient(config *ClientConfig) (client *Client) {
client = new(Client)
client.ctx = gctx.New()
client.logger = g.Log("tcpClient")
baseErr := gerror.New("NewClient fail")
if config == nil {
err = gerror.New("config is nil")
client.logger.Fatal(client.ctx, gerror.Wrap(baseErr, "config is nil"))
return
}
if config.Addr == "" {
err = gerror.New("client address is not set")
client.logger.Fatal(client.ctx, gerror.Wrap(baseErr, "client address is not set"))
return
}
if config.Auth == nil {
err = gerror.New("client auth cannot be empty")
return
}
if config.Auth.Group == "" || config.Auth.Name == "" {
err = gerror.New("Auth.Group or Auth.Group is nil")
return
}
client.Ctx = gctx.New()
client.autoReconnect = true
client.addr = config.Addr
client.auth = config.Auth
client.loginEvent = config.LoginEvent
client.closeEvent = config.CloseEvent
client.Logger = g.Log("tcpClient")
client.config = config
if config.ConnectInterval <= 0 {
client.connectInterval = 5 * time.Second
client.config.ConnectInterval = 5 * time.Second
} else {
client.connectInterval = config.ConnectInterval
client.config.ConnectInterval = config.ConnectInterval
}
if config.Timeout <= 0 {
client.timeout = 10 * time.Second
client.config.Timeout = 10 * time.Second
} else {
client.timeout = config.Timeout
client.config.Timeout = config.Timeout
}
client.msgGo = grpool.New(5)
client.rpc = NewRpc(client.Ctx, client.msgGo, client.Logger)
client.isLogin = gtype.NewBool(false)
client.closeFlag = gtype.NewBool(false)
client.stopFlag = gtype.NewBool(false)
client.taskGo = grpool.New(5)
client.msgParser = NewMsgParser(client.handleRoutineTask)
client.registerDefaultRouter()
return
}
// Start 启动tcp连接
func (client *Client) Start() (err error) {
client.Lock()
defer client.Unlock()
client.mutex.Lock()
defer client.mutex.Unlock()
if client.stopFlag {
if client.stopFlag.Val() {
err = gerror.New("client is stop")
return
}
@@ -121,64 +105,66 @@ func (client *Client) Start() (err error) {
return gerror.New("client is running")
}
client.IsLogin = false
client.connectCount = 0
client.closeFlag = false
client.stopFlag = false
client.isLogin.Set(false)
client.config.ConnectCount = 0
client.closeFlag.Set(false)
client.stopFlag.Set(false)
client.wg.Add(1)
simple.SafeGo(client.Ctx, func(ctx context.Context) {
simple.SafeGo(client.ctx, func(ctx context.Context) {
client.connect()
})
return
}
// registerDefaultRouter 注册默认路由
func (client *Client) registerDefaultRouter() {
var routers = []interface{}{
client.onResponseServerLogin, // 服务登录
client.onResponseServerHeartbeat, // 心跳
}
client.RegisterRouter(routers...)
}
// RegisterRouter 注册路由
func (client *Client) RegisterRouter(routers map[string]RouterHandler) (err error) {
if client.conn != nil {
return gerror.New("client is running")
func (client *Client) RegisterRouter(routers ...interface{}) {
err := client.msgParser.RegisterRouter(routers...)
if err != nil {
client.logger.Fatal(client.ctx, err)
}
}
client.Lock()
defer client.Unlock()
if client.routers == nil {
client.routers = make(map[string]RouterHandler)
// 默认路由
client.routers = map[string]RouterHandler{
"ResponseServerHeartbeat": client.onResponseServerHeartbeat,
"ResponseServerLogin": client.onResponseServerLogin,
}
// RegisterRPCRouter 注册RPC路由
func (client *Client) RegisterRPCRouter(routers ...interface{}) {
err := client.msgParser.RegisterRPCRouter(routers...)
if err != nil {
client.logger.Fatal(client.ctx, err)
}
}
for i, router := range routers {
_, ok := client.routers[i]
if ok {
return gerror.Newf("client route duplicate registration:%v", i)
}
client.routers[i] = router
}
return
// RegisterInterceptor 注册拦截器
func (client *Client) RegisterInterceptor(interceptors ...Interceptor) {
client.msgParser.RegisterInterceptor(interceptors...)
}
// dial
func (client *Client) dial() *gtcp.Conn {
for {
conn, err := gtcp.NewConn(client.addr, client.timeout)
if err == nil || client.closeFlag {
conn, err := gtcp.NewConn(client.config.Addr, client.config.Timeout)
if err == nil || client.closeFlag.Val() {
return conn
}
if client.maxConnectCount > 0 {
if client.connectCount < client.maxConnectCount {
client.connectCount += 1
if client.config.MaxConnectCount > 0 {
if client.config.ConnectCount < client.config.MaxConnectCount {
client.config.ConnectCount += 1
} else {
return nil
}
}
client.Logger.Debugf(client.Ctx, "connect to %v error: %v", client.addr, err)
time.Sleep(client.connectInterval)
client.logger.Debugf(client.ctx, "connect to %v error: %v", client.config.Addr, err)
time.Sleep(client.config.ConnectInterval)
continue
}
}
@@ -190,26 +176,24 @@ func (client *Client) connect() {
reconnect:
conn := client.dial()
if conn == nil {
if !client.stopFlag {
client.Logger.Debugf(client.Ctx, "client dial failed")
if !client.stopFlag.Val() {
client.logger.Debugf(client.ctx, "client dial failed")
}
return
}
client.Lock()
if client.closeFlag {
client.Unlock()
_ = conn.Close()
client.Logger.Debugf(client.Ctx, "client connect but closeFlag is true")
client.mutex.Lock()
if client.closeFlag.Val() {
client.mutex.Unlock()
conn.Close()
client.logger.Debugf(client.ctx, "client connect but closeFlag is true")
return
}
client.conn = conn
client.connectCount = 0
client.heartbeat = gtime.Timestamp()
client.conn = NewConn(conn, client.logger, client.msgParser)
client.config.ConnectCount = 0
client.read()
client.Unlock()
client.mutex.Unlock()
client.serverLogin()
client.startCron()
@@ -217,102 +201,38 @@ reconnect:
// read
func (client *Client) read() {
simple.SafeGo(client.Ctx, func(ctx context.Context) {
go func() {
defer func() {
client.Close()
client.Logger.Debugf(client.Ctx, "client are about to be reconnected..")
time.Sleep(client.connectInterval)
_ = client.Start()
client.logger.Debugf(client.ctx, "client are about to be reconnected..")
time.Sleep(client.config.ConnectInterval)
client.Start()
}()
for {
if client.conn == nil {
client.Logger.Debugf(client.Ctx, "client client.conn is nil, server closed")
break
}
msg, err := RecvPkg(client.conn)
if err != nil {
client.Logger.Debugf(client.Ctx, "client RecvPkg err:%+v, server closed", err)
break
}
if client.routers == nil {
client.Logger.Debugf(client.Ctx, "client RecvPkg routers is nil")
break
}
if msg == nil {
client.Logger.Debugf(client.Ctx, "client RecvPkg msg is nil")
break
}
f, ok := client.routers[msg.Router]
if !ok {
client.Logger.Debugf(client.Ctx, "client RecvPkg invalid message: %+v", msg)
continue
}
switch msg.Router {
case "ResponseServerLogin", "ResponseServerHeartbeat": // 服务登录、心跳无需验证签名
client.doHandleRouterMsg(initCtx(gctx.New(), &Context{}), f, msg.Data)
default: // 通用路由消息处理
in, err := VerifySign(msg.Data, client.auth.AppId, client.auth.SecretKey)
if err != nil {
client.Logger.Warningf(client.Ctx, "client read VerifySign err:%+v message: %+v", err, msg)
continue
}
ctx := initCtx(gctx.New(), &Context{
Conn: client.conn,
Auth: client.auth,
TraceID: in.TraceID,
})
// 响应rpc消息
if client.rpc.HandleMsg(ctx, msg.Data) {
return
}
client.doHandleRouterMsg(ctx, f, msg.Data)
}
if err := client.conn.Run(); err != nil {
client.logger.Debug(client.ctx, err)
}
})
}()
}
// Close 关闭同服务器的链接
func (client *Client) Close() {
client.Lock()
defer client.Unlock()
client.mutex.Lock()
defer client.mutex.Unlock()
client.IsLogin = false
client.closeFlag = true
client.isLogin.Set(false)
client.closeFlag.Set(true)
if client.conn != nil {
client.conn.Close()
client.conn = nil
}
if client.closeEvent != nil {
client.closeEvent()
if client.config.CloseEvent != nil {
client.config.CloseEvent()
}
client.wg.Wait()
}
// Stop 停止服务
func (client *Client) Stop() {
if client.stopFlag {
return
}
client.stopFlag = true
client.stopCron()
client.Close()
}
// IsStop 是否已停止
func (client *Client) IsStop() bool {
return client.stopFlag
}
// Destroy 销毁当前连接
func (client *Client) Destroy() {
client.stopCron()
@@ -322,80 +242,62 @@ func (client *Client) Destroy() {
}
}
// Write
func (client *Client) Write(data interface{}) error {
client.Lock()
defer client.Unlock()
if client.conn == nil {
return gerror.New("client conn is nil")
// Stop 停止服务
func (client *Client) Stop() {
if client.stopFlag.Val() {
return
}
client.stopFlag.Set(true)
client.stopCron()
client.Close()
}
if client.closeFlag {
return gerror.New("client conn is closed")
}
// IsStop 是否已停止
func (client *Client) IsStop() bool {
return client.stopFlag.Val()
}
if data == nil {
return gerror.New("client Write message is nil")
}
// IsLogin 是否已登录成功
func (client *Client) IsLogin() bool {
return client.isLogin.Val()
}
msgType := reflect.TypeOf(data)
if msgType == nil || msgType.Kind() != reflect.Ptr {
return gerror.Newf("client json message pointer required: %+v", data)
func (client *Client) handleRoutineTask(ctx context.Context, task func()) {
ctx, cancel := context.WithCancel(ctx)
err := client.taskGo.AddWithRecover(ctx,
func(ctx context.Context) {
task()
cancel()
},
func(ctx context.Context, err error) {
client.logger.Warningf(ctx, "routineTask exec err:%+v", err)
cancel()
},
)
if err != nil {
client.logger.Warningf(ctx, "routineTask add err:%+v", err)
}
msg := &Message{Router: msgType.Elem().Name(), Data: data}
return SendPkg(client.conn, msg)
}
// Conn 获取当前连接
func (client *Client) Conn() *Conn {
return client.conn
}
// Send 发送消息
func (client *Client) Send(ctx context.Context, data interface{}) error {
MsgPkg(data, client.auth, gctx.CtxId(ctx))
return client.Write(data)
}
// Reply 回复消息
func (client *Client) Reply(ctx context.Context, data interface{}) (err error) {
user := GetCtx(ctx)
if user == nil {
err = gerror.New("获取回复用户信息失败")
return
if client.conn == nil {
return gerror.New("conn is nil")
}
MsgPkg(data, client.auth, user.TraceID)
return client.Write(data)
return client.conn.Send(ctx, data)
}
// RpcRequest 发送消息并等待响应结果
func (client *Client) RpcRequest(ctx context.Context, data interface{}) (res interface{}, err error) {
var (
traceID = MsgPkg(data, client.auth, gctx.CtxId(ctx))
key = client.rpc.GetCallId(client.conn, traceID)
)
if traceID == "" {
err = gerror.New("traceID is required")
return
}
return client.rpc.Request(key, func() {
_ = client.Write(data)
})
// Request 发送消息并等待响应结果
func (client *Client) Request(ctx context.Context, data interface{}) (interface{}, error) {
return client.conn.Request(ctx, data)
}
// doHandleRouterMsg 处理路由消息
func (client *Client) doHandleRouterMsg(ctx context.Context, fun RouterHandler, args ...interface{}) {
ctx, cancel := context.WithCancel(ctx)
err := client.msgGo.AddWithRecover(ctx,
func(ctx context.Context) {
fun(ctx, args...)
cancel()
},
func(ctx context.Context, err error) {
client.Logger.Warningf(ctx, "doHandleRouterMsg msgGo exec err:%+v", err)
cancel()
},
)
if err != nil {
client.Logger.Warningf(ctx, "doHandleRouterMsg msgGo Add err:%+v", err)
return
}
// RequestScan 发送消息并等待响应结果将结果保存在response中
func (client *Client) RequestScan(ctx context.Context, data, response interface{}) error {
return client.conn.RequestScan(ctx, data, response)
}

View File

@@ -10,12 +10,11 @@ import (
"fmt"
"github.com/gogf/gf/v2/os/gcron"
"github.com/gogf/gf/v2/os/gtime"
"hotgo/internal/consts"
)
// getCronKey 生成客户端定时任务名称
func (client *Client) getCronKey(s string) string {
return fmt.Sprintf("tcp.client_%s_%s:%s", s, client.auth.Group, client.auth.Name)
return fmt.Sprintf("tcp.client_%s:%s", s, client.conn.LocalAddr().String())
}
// stopCron 停止定时任务
@@ -28,19 +27,22 @@ func (client *Client) stopCron() {
// startCron 启动定时任务
func (client *Client) startCron() {
// 心跳超时检查
if gcron.Search(client.getCronKey(consts.TCPCronHeartbeatVerify)) == nil {
_, _ = gcron.AddSingleton(client.Ctx, "@every 600s", func(ctx context.Context) {
if client.heartbeat < gtime.Timestamp()-consts.TCPHeartbeatTimeout {
client.Logger.Debugf(client.Ctx, "client heartbeat timeout, about to reconnect..")
if gcron.Search(client.getCronKey(CronHeartbeatVerify)) == nil {
_, _ = gcron.AddSingleton(client.ctx, "@every 600s", func(ctx context.Context) {
if client == nil || client.conn == nil {
return
}
if client.conn.Heartbeat < gtime.Timestamp()-HeartbeatTimeout {
client.logger.Debugf(client.ctx, "client heartbeat timeout, about to reconnect..")
client.Destroy()
}
}, client.getCronKey(consts.TCPCronHeartbeatVerify))
}, client.getCronKey(CronHeartbeatVerify))
}
// 心跳
if gcron.Search(client.getCronKey(consts.TCPCronHeartbeat)) == nil {
_, _ = gcron.AddSingleton(client.Ctx, "@every 120s", func(ctx context.Context) {
if gcron.Search(client.getCronKey(CronHeartbeat)) == nil {
_, _ = gcron.AddSingleton(client.ctx, "@every 300s", func(ctx context.Context) {
client.serverHeartbeat()
}, client.getCronKey(consts.TCPCronHeartbeat))
}, client.getCronKey(CronHeartbeat))
}
}

View File

@@ -7,70 +7,73 @@ package tcp
import (
"context"
"fmt"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/consts"
"hotgo/internal/model/input/msgin"
"hotgo/utility/encrypt"
)
// serverLogin 心跳
func (client *Client) serverHeartbeat() {
if !client.isLogin.Val() {
return
}
ctx := gctx.New()
if err := client.Send(ctx, &msgin.ServerHeartbeat{}); err != nil {
client.Logger.Debugf(ctx, "client WriteMsg ServerHeartbeat err:%+v", err)
if err := client.conn.Send(ctx, &ServerHeartbeatReq{}); err != nil {
client.logger.Warningf(ctx, "client ServerHeartbeat Send err:%+v", err)
return
}
}
// serverLogin 服务登陆
func (client *Client) serverLogin() {
data := &msgin.ServerLogin{
Group: client.auth.Group,
Name: client.auth.Name,
auth := client.config.Auth
// 无需登录
if auth == nil {
return
}
data := &ServerLoginReq{
Name: auth.Name,
Extra: auth.Extra,
Group: auth.Group,
AppId: auth.AppId,
Timestamp: gtime.Timestamp(),
}
// 签名
data.Sign = encrypt.Md5ToString(fmt.Sprintf("%v%v%v", data.AppId, data.Timestamp, auth.SecretKey))
ctx := gctx.New()
if err := client.Send(ctx, data); err != nil {
client.Logger.Debugf(ctx, "client WriteMsg ServerLogin err:%+v", err)
if err := client.conn.Send(ctx, data); err != nil {
client.logger.Warningf(ctx, "client ServerLogin Send err:%+v", err)
return
}
}
// onResponseServerLogin 接收服务登陆响应结果
func (client *Client) onResponseServerLogin(ctx context.Context, args ...interface{}) {
var in *msgin.ResponseServerLogin
if err := gconv.Scan(args[0], &in); err != nil {
client.Logger.Infof(ctx, "onResponseServerLogin message Scan failed:%+v, args:%+v", err, args[0])
return
}
if in.Code != consts.TCPMsgCodeSuccess {
client.IsLogin = false
client.Logger.Warningf(ctx, "onResponseServerLogin quit err:%v", in.Message)
func (client *Client) onResponseServerLogin(ctx context.Context, req *ServerLoginRes) {
if err := req.GetError(); err != nil {
client.isLogin.Set(false)
client.logger.Warningf(ctx, "onResponseServerLogin destroy, err:%v", err)
client.Destroy()
return
}
client.IsLogin = true
client.isLogin.Set(true)
if client.loginEvent != nil {
client.loginEvent()
if client.config.LoginEvent != nil {
client.config.LoginEvent()
}
}
// onResponseServerHeartbeat 接收心跳响应结果
func (client *Client) onResponseServerHeartbeat(ctx context.Context, args ...interface{}) {
var in *msgin.ResponseServerHeartbeat
if err := gconv.Scan(args[0], &in); err != nil {
client.Logger.Infof(ctx, "onResponseServerHeartbeat message Scan failed:%+v, args:%+v", err, args)
func (client *Client) onResponseServerHeartbeat(ctx context.Context, req *ServerHeartbeatRes) {
if err := req.GetError(); err != nil {
client.logger.Warningf(ctx, "onResponseServerHeartbeat err:%v", err)
return
}
if in.Code != consts.TCPMsgCodeSuccess {
client.Logger.Warningf(ctx, "onResponseServerHeartbeat err:%v", in.Message)
return
}
client.heartbeat = gtime.Timestamp()
client.conn.Heartbeat = gtime.Timestamp()
}

View File

@@ -0,0 +1,177 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"context"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/grand"
"net"
"sync/atomic"
)
// AuthMeta 认证元数据
type AuthMeta struct {
Name string `json:"name"` // 客户端名称当同一个应用ID有多个客户端时请使用不同的名称区分。比如cron1,cron2
Extra g.Map `json:"extra"` // 自定义数据,可以传递一些额外的自定义数据
Group string `json:"group"` // 客户端分组
AppId string `json:"appId"` // 应用ID
SecretKey string `json:"secretKey"` // 应用秘钥
EndAt *gtime.Time `json:"endAt"` // 授权过期时间
Routes []string `json:"routes"` // 授权路由
}
// Conn tcp连接
type Conn struct {
CID int64 // 连接ID
Conn *gtcp.Conn // 连接对象
Auth *AuthMeta // 认证元数据
Heartbeat int64 // 心跳
FirstTime int64 // 首次连接时间
writeChan chan []byte // 发数据
closeFlag *gtype.Bool // 关闭标签
logger *glog.Logger // 日志处理器
msgParser *MsgParser // 消息处理器
}
var idCounter int64
func NewConn(conn *gtcp.Conn, logger *glog.Logger, msgParser *MsgParser) *Conn {
tcpConn := new(Conn)
tcpConn.CID = atomic.AddInt64(&idCounter, 1)
tcpConn.Conn = conn
tcpConn.Heartbeat = gtime.Timestamp()
tcpConn.FirstTime = gtime.Timestamp()
tcpConn.writeChan = make(chan []byte, 1000)
tcpConn.closeFlag = gtype.NewBool(false)
tcpConn.logger = logger
tcpConn.msgParser = msgParser
go func() {
for b := range tcpConn.writeChan {
if b == nil {
break
}
if err := conn.SendPkg(b); err != nil {
break
}
}
}()
return tcpConn
}
func (c *Conn) Run() error {
for {
data, err := c.Conn.RecvPkg()
if err != nil {
return gerror.NewCodef(gcode.CodeInvalidRequest, "read packet err:%+v conn closed", err)
}
if c.closeFlag.Val() {
return nil
}
msg, err := c.msgParser.Encoding(data)
if err != nil {
return gerror.NewCodef(gcode.CodeInternalError, "message encoding err:%+v conn closed", err)
}
ctx, err := c.bindContext(msg)
if err != nil {
return gerror.NewCodef(gcode.CodeInternalError, "bindContext err:%+v message: %+v", err, msg)
}
if err = c.msgParser.handleInterceptor(ctx, msg); err != nil {
c.logger.Warning(ctx, gerror.Wrap(err, "interceptor authentication failed"))
continue
}
if err = c.msgParser.handleRouterMsg(ctx, msg); err != nil {
return err
}
}
}
// RemoteAddr returns the remote network address, if known.
func (c *Conn) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
}
// LocalAddr returns the local network address, if known.
func (c *Conn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
// Write
func (c *Conn) Write(b []byte) {
if !c.closeFlag.Val() {
c.writeChan <- b
}
}
// Send 发送消息
func (c *Conn) Send(ctx context.Context, data interface{}) error {
if c.closeFlag.Val() {
return gerror.New("conn is closed")
}
b, err := c.msgParser.Decoding(ctx, data, "")
if err != nil {
return err
}
c.Write(b)
return nil
}
func (c *Conn) Close() {
if c.closeFlag.Val() {
return
}
c.closeFlag.Set(true)
close(c.writeChan)
c.Conn.Close()
}
// Request 发送消息并等待响应结果
func (c *Conn) Request(ctx context.Context, data interface{}) (interface{}, error) {
if c.closeFlag.Val() {
return nil, gerror.New("conn is closed")
}
msgId := grand.S(16)
b, err := c.msgParser.Decoding(ctx, data, msgId)
if err != nil {
return nil, err
}
return c.msgParser.rpc.Request(ctx, msgId, func() {
c.Write(b)
})
}
// RequestScan 发送消息并等待响应结果将结果保存在response中
func (c *Conn) RequestScan(ctx context.Context, data, response interface{}) error {
body, err := c.Request(ctx, data)
if err != nil {
return err
}
return gvar.New(body).Scan(response)
}
// bindContext 将用户身份绑定到上下文
func (c *Conn) bindContext(msg *Message) (ctx context.Context, err error) {
ctx = initCtx(gctx.New(), &Context{
Conn: c,
})
return SetCtxTraceID(ctx, msg.TraceId)
}

View File

@@ -0,0 +1,32 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
// 定时任务
const (
CronHeartbeatVerify = "tcpHeartbeatVerify"
CronHeartbeat = "tcpHeartbeat"
CronAuthVerify = "tcpAuthVerify"
)
const (
HeartbeatTimeout = 300 // tcp心跳超时默认300s
RPCTimeout = 10 // rpc通讯超时时间 默认10s
)
const (
ParseRouterErrInvalidParams = "register router[%v] method must have two params"
ParseRouterRPCErrInvalidParams = "register RPC router [%v] method must have two response params"
ParseRouterErrInvalidFirstParam = "the first params of the processing method that registers the router[%v] must be of type context.Context"
ParseRouterErrInvalidSecondParam = "the second params of the processing method that registers the router[%v] must be of type pointer to a struct"
)
type CtxKey string
// ContextKey 上下文
const (
ContextKey CtxKey = "tcpContext" // tcp上下文变量名称
)

View File

@@ -8,23 +8,21 @@ package tcp
import (
"context"
"github.com/gogf/gf/v2/net/gtrace"
"hotgo/internal/consts"
)
// Context tcp上下文
type Context struct {
Conn *Conn
}
// initCtx 初始化上下文对象指针到上下文对象中,以便后续的请求流程中可以修改
func initCtx(ctx context.Context, model *Context) (newCtx context.Context) {
if model.TraceID != "" {
newCtx, _ = gtrace.WithTraceID(ctx, model.TraceID)
} else {
newCtx = ctx
}
newCtx = context.WithValue(newCtx, consts.ContextTCPKey, model)
return
func initCtx(ctx context.Context, model *Context) context.Context {
return context.WithValue(ctx, ContextKey, model)
}
// GetCtx 获得上下文变量如果没有设置那么返回nil
func GetCtx(ctx context.Context) *Context {
value := ctx.Value(consts.ContextTCPKey)
value := ctx.Value(ContextKey)
if value == nil {
return nil
}
@@ -33,3 +31,20 @@ func GetCtx(ctx context.Context) *Context {
}
return nil
}
// ConnFromCtx retrieves and returns the Conn object from context.
func ConnFromCtx(ctx context.Context) *Conn {
user := GetCtx(ctx)
if user == nil {
return nil
}
return user.Conn
}
// SetCtxTraceID 将自定义跟踪ID注入上下文以进行传播
func SetCtxTraceID(ctx context.Context, traceID string) (context.Context, error) {
if len(traceID) > 0 {
return gtrace.WithTraceID(ctx, traceID)
}
return ctx, nil
}

View File

@@ -1,30 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gtime"
)
// AuthMeta 认证元数据
type AuthMeta struct {
Group string `json:"group"`
Name string `json:"name"`
AppId string `json:"appId"`
SecretKey string `json:"secretKey"`
EndAt *gtime.Time `json:"-"`
}
// Context tcp上下文
type Context struct {
Conn *gtcp.Conn `json:"conn"`
Auth *AuthMeta `json:"auth"` // 认证元数据
TraceID string `json:"traceID"` // 链路ID
}
// CallbackEvent 回调事件
type CallbackEvent func()

View File

@@ -0,0 +1,84 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
)
type ServerRes struct {
Code int `json:"code" example:"2000" description:"状态码"`
Message string `json:"message,omitempty" example:"操作成功" description:"提示消息"`
}
// SetCode 设置状态码
func (i *ServerRes) SetCode(code ...int) {
if len(code) > 0 {
i.Code = code[0]
return
}
// 默认值,转为成功的状态码
if i.Code == 0 {
i.Code = gcode.CodeOK.Code()
}
}
// SetMessage 设置提示消息
func (i *ServerRes) SetMessage(msg ...string) {
message := "操作成功"
if len(msg) > 0 {
message = msg[0]
return
}
i.Message = message
}
// SetError 设置响应中的错误
func (i *ServerRes) SetError(err error) {
if err != nil {
i.Code = gerror.Code(err).Code()
i.Message = err.Error()
}
return
}
// GetError 获取响应中的错误
func (i *ServerRes) GetError() (err error) {
if i.Code != gcode.CodeOK.Code() {
if i.Message == "" {
i.Message = "操作失败"
}
err = gerror.NewCode(gcode.New(i.Code, i.Message, ""))
}
return
}
// ServerLoginReq 服务登录
type ServerLoginReq struct {
Name string `json:"name" description:"客户端名称"` // 客户端名称当同一个应用ID有多个客户端时请使用不同的名称区分。比如cron1,cron2
Extra g.Map `json:"extra" description:"自定义数据"` // 自定义数据,可以传递一些额外的自定义数据
Group string `json:"group" description:"分组"`
AppId string `json:"appID" description:"应用ID"`
Timestamp int64 `json:"timestamp" description:"服务器时间戳"`
Sign string `json:"sign" description:"签名"`
}
// ServerLoginRes 响应服务登录
type ServerLoginRes struct {
ServerRes
}
// ServerHeartbeatReq 心跳
type ServerHeartbeatReq struct {
}
// ServerHeartbeatRes 响应心跳
type ServerHeartbeatRes struct {
ServerRes
}

View File

@@ -0,0 +1,222 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"context"
"encoding/json"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
"reflect"
"sync"
)
// MsgParser 消息处理器
type MsgParser struct {
mutex sync.Mutex // 路由锁
task RoutineTask // 投递协程任务
rpc *RPC // rpc
routers map[string]*RouteHandler // 已注册的路由
interceptors []Interceptor // 拦截器
}
// RoutineTask 投递协程任务
type RoutineTask func(ctx context.Context, task func())
// Interceptor 拦截器
type Interceptor func(ctx context.Context, msg *Message) (err error)
// Message 标准消息
type Message struct {
Router string `json:"router"` // 路由
TraceId string `json:"traceId"` // 链路ID
Data interface{} `json:"data"` // 数据
MsgId string `json:"msgId,omitempty"` // 消息IDrpc用
Error string `json:"error,omitempty"` // 消息错误rpc用
}
// NewMsgParser 初始化消息处理器
func NewMsgParser(task RoutineTask) *MsgParser {
m := new(MsgParser)
m.task = task
m.routers = make(map[string]*RouteHandler)
m.rpc = NewRPC(task)
return m
}
// RegisterRouter 注册路由
func (m *MsgParser) RegisterRouter(routers ...interface{}) (err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, router := range routers {
info, err := ParseRouteHandler(router, false)
if err != nil {
return err
}
if _, ok := m.routers[info.Id]; ok {
return gerror.Newf("server router duplicate registration:%v", info.Id)
}
m.routers[info.Id] = info
}
return
}
// RegisterRPCRouter 注册rpc路由
func (m *MsgParser) RegisterRPCRouter(routers ...interface{}) (err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, router := range routers {
info, err := ParseRouteHandler(router, true)
if err != nil {
return err
}
if _, ok := m.routers[info.Id]; ok {
return gerror.Newf("server rpc router duplicate registration:%v", info.Id)
}
m.routers[info.Id] = info
}
return
}
// RegisterInterceptor 注册拦截器
func (m *MsgParser) RegisterInterceptor(interceptors ...Interceptor) {
m.interceptors = append(interceptors, interceptors...)
return
}
// Encoding 消息编码
func (m *MsgParser) Encoding(data []byte) (*Message, error) {
var msg Message
if err := gconv.Scan(data, &msg); err != nil {
return nil, gerror.Newf("invalid package struct: %s", err.Error())
}
if msg.Router == "" {
return nil, gerror.Newf("message is not router: %+v", msg)
}
return &msg, nil
}
// Decoding 消息解码
func (m *MsgParser) Decoding(ctx context.Context, data interface{}, msgId string) ([]byte, error) {
message, err := m.doDecoding(ctx, data, msgId)
if err != nil {
return nil, err
}
return json.Marshal(message)
}
// Decoding 消息解码
func (m *MsgParser) doDecoding(ctx context.Context, data interface{}, msgId string) (*Message, error) {
msgType := reflect.TypeOf(data)
if msgType == nil || msgType.Kind() != reflect.Ptr {
return nil, gerror.Newf("json message pointer required: %+v", data)
}
message := &Message{Router: msgType.Elem().Name(), TraceId: gctx.CtxId(ctx), MsgId: msgId, Data: data}
return message, nil
}
// handleRouterMsg 处理路由消息
func (m *MsgParser) handleRouterMsg(ctx context.Context, msg *Message) error {
// rpc消息
if m.rpc.Response(ctx, msg) {
return nil
}
handler, ok := m.routers[msg.Router]
if !ok {
return gerror.NewCodef(gcode.CodeInternalError, "invalid message: %+v, or the router is not registered.", msg)
}
return m.doHandleRouterMsg(ctx, handler, msg)
}
// doHandleRouterMsg 处理路由消息
func (m *MsgParser) doHandleRouterMsg(ctx context.Context, handler *RouteHandler, msg *Message) (err error) {
var input = gutil.Copy(handler.Input.Interface())
if err = gjson.New(msg.Data).Scan(input); err != nil {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
"router scan failed:%v to parse message:%v",
err, handler.Id)
}
args := []reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(input)}
m.task(ctx, func() {
results := handler.Func.Call(args)
if handler.IsRPC {
switch len(results) {
case 2:
var responseErr error
if !results[1].IsNil() {
if err, ok := results[1].Interface().(error); ok {
responseErr = err
}
}
responseMsg, deErr := m.doDecoding(ctx, results[0].Interface(), msg.MsgId)
if deErr != nil && responseErr == nil {
responseErr = deErr
return
}
if responseErr != nil {
responseMsg.Error = responseErr.Error()
}
b, err := json.Marshal(responseMsg)
if err != nil {
return
}
user := GetCtx(ctx)
user.Conn.Write(b)
}
return
}
})
return
}
// handleInterceptor 处理拦截器
func (m *MsgParser) handleInterceptor(ctx context.Context, msg *Message) (interceptErr error) {
for _, f := range m.interceptors {
if interceptErr = f(ctx, msg); interceptErr != nil {
break
}
}
if interceptErr == nil {
return
}
handler, ok := m.routers[msg.Router]
if !ok {
return
}
if handler.IsRPC {
var output = gutil.Copy(handler.Output.Interface())
response, doerr := m.doDecoding(ctx, output, msg.MsgId)
if doerr != nil {
return doerr
}
response.Error = interceptErr.Error()
b, err := json.Marshal(response)
if err != nil {
return err
}
ConnFromCtx(ctx).Write(b)
}
return
}

View File

@@ -1,27 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
type Response interface {
PkgResponse()
GetError() (err error)
}
// PkgResponse 打包响应消息
func PkgResponse(data interface{}) {
if c, ok := data.(Response); ok {
c.PkgResponse()
return
}
}
// GetResponseError 解析响应消息中的错误
func GetResponseError(data interface{}) (err error) {
if c, ok := data.(Response); ok {
return c.GetError()
}
return
}

View File

@@ -7,56 +7,88 @@ package tcp
import (
"context"
"encoding/json"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/text/gstr"
"reflect"
"runtime"
)
// RouterHandler 路由消息处理器
type RouterHandler func(ctx context.Context, args ...interface{})
// Message 路由消息
type Message struct {
Router string `json:"router"`
Data interface{} `json:"data"`
// RouteHandler 路由处理器
type RouteHandler struct {
Id string // 路由ID
IsRPC bool // 是否支持rpc协议
Func reflect.Value // 路由处理方法
Input reflect.Value // 输入参数
Output reflect.Value // 输出参数
}
// SendPkg 打包发送的数据包
func SendPkg(conn *gtcp.Conn, message *Message) error {
b, err := json.Marshal(message)
if err != nil {
return err
// ParseRouteHandler 解析路由
func ParseRouteHandler(router interface{}, isRPC bool) (info *RouteHandler, err error) {
funcName := runtime.FuncForPC(reflect.ValueOf(router).Pointer()).Name()
funcType := reflect.ValueOf(router).Type()
if funcType.NumIn() != 2 {
err = gerror.Newf(ParseRouterErrInvalidParams, funcName)
return
}
return conn.SendPkg(b)
}
// RecvPkg 解包
func RecvPkg(conn *gtcp.Conn) (*Message, error) {
if data, err := conn.RecvPkg(); err != nil {
return nil, err
} else {
var msg = new(Message)
if err = gconv.Scan(data, &msg); err != nil {
return nil, gerror.Newf("invalid package structure: %s", err.Error())
}
if msg.Router == "" {
return nil, gerror.Newf("message is not routed: %+v", msg)
}
return msg, err
if funcType.In(0) != reflect.TypeOf((*context.Context)(nil)).Elem() {
err = gerror.Newf(ParseRouterErrInvalidFirstParam, funcName)
return
}
}
// MsgPkg 打包消息
func MsgPkg(data interface{}, auth *AuthMeta, traceID string) string {
// 打包签名
msg := PkgSign(data, auth.AppId, auth.SecretKey, traceID)
// 打包响应消息
PkgResponse(data)
if msg == nil {
return ""
inputType := funcType.In(1)
if !(inputType.Kind() == reflect.Ptr && inputType.Elem().Kind() == reflect.Struct) {
err = gerror.Newf(ParseRouterErrInvalidSecondParam, funcName)
return
}
return msg.TraceID
// The request struct should be named as `xxxReq`.
if !gstr.HasSuffix(inputType.String(), `Req`) && !gstr.HasSuffix(inputType.String(), `Res`) {
err = gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid struct naming of the request: defined as "%s", but should be named with the "Req" or "Res" suffix, such as "XxxReq" or "XxxRes"`,
inputType.String(),
)
return
}
info = &RouteHandler{
Id: gstr.SubStrFromREx(inputType.String(), `.`),
IsRPC: isRPC,
Func: reflect.ValueOf(router),
Input: reflect.New(inputType.Elem()),
}
if !isRPC {
return
}
if funcType.NumOut() != 2 {
err = gerror.Newf(ParseRouterRPCErrInvalidParams, funcName)
return
}
outputType := funcType.Out(0)
// The response struct should be named as `xxxRes`.
if !gstr.HasSuffix(outputType.String(), `Res`) {
err = gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid struct naming for response: defined as "%s", but it should be named with "Res" suffix like "XxxRes"`,
outputType.String(),
)
return
}
if !funcType.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
err = gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid handler: defined as "%s", but the last output parameter should be type of "error"`,
reflect.TypeOf(funcType).String(),
)
return
}
info.Output = reflect.New(outputType.Elem())
return
}

View File

@@ -7,118 +7,95 @@ package tcp
import (
"context"
"fmt"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
"hotgo/internal/consts"
"hotgo/utility/simple"
"sync"
"time"
)
type Rpc struct {
ctx context.Context
// RPC .
type RPC struct {
mutex sync.Mutex
callbacks map[string]RpcRespFunc
msgGo *grpool.Pool // 消息处理协程池
logger *glog.Logger // 日志处理器
callbacks map[string]RPCResponseFunc
task RoutineTask
}
// RpcResp 响应结构
type RpcResp struct {
// RPCResponse 响应结构
type RPCResponse struct {
res interface{}
err error
}
type RpcRespFunc func(resp interface{}, err error)
type RPCResponseFunc func(resp interface{}, err error)
// NewRpc 初始化一个rpc协议
func NewRpc(ctx context.Context, msgGo *grpool.Pool, logger *glog.Logger) *Rpc {
return &Rpc{
ctx: ctx,
callbacks: make(map[string]RpcRespFunc),
msgGo: msgGo,
logger: logger,
// NewRPC 初始化RPC
func NewRPC(task RoutineTask) *RPC {
return &RPC{
task: task,
callbacks: make(map[string]RPCResponseFunc),
}
}
// GetCallId 获取回调id
func (r *Rpc) GetCallId(client *gtcp.Conn, traceID string) string {
return fmt.Sprintf("%v.%v", client.LocalAddr().String(), traceID)
}
// HandleMsg 处理rpc消息
func (r *Rpc) HandleMsg(ctx context.Context, data interface{}) bool {
user := GetCtx(ctx)
callId := r.GetCallId(user.Conn, user.TraceID)
if call, ok := r.callbacks[callId]; ok {
r.mutex.Lock()
delete(r.callbacks, callId)
r.mutex.Unlock()
ctx, cancel := context.WithCancel(ctx)
err := r.msgGo.AddWithRecover(ctx, func(ctx context.Context) {
call(data, nil)
cancel()
}, func(ctx context.Context, err error) {
r.logger.Warningf(ctx, "rpc HandleMsg msgGo exec err:%+v", err)
cancel()
})
if err != nil {
r.logger.Warningf(ctx, "rpc HandleMsg msgGo Add err:%+v", err)
}
return true
}
return false
}
// Request 发起rpc请求
func (r *Rpc) Request(callId string, send func()) (res interface{}, err error) {
var (
waitCh = make(chan struct{})
resCh = make(chan RpcResp, 1)
isClose = false
)
// Request 发起RPC请求
func (r *RPC) Request(ctx context.Context, msgId string, send func()) (res interface{}, err error) {
resCh := make(chan RPCResponse, 1)
isClose := gtype.NewBool(false)
defer func() {
isClose = true
isClose.Set(true)
close(resCh)
// 移除消息
if _, ok := r.callbacks[callId]; ok {
r.mutex.Lock()
delete(r.callbacks, callId)
r.mutex.Unlock()
}
r.popCallback(msgId)
}()
simple.SafeGo(r.ctx, func(ctx context.Context) {
close(waitCh)
// 加入回调
r.mutex.Lock()
r.callbacks[callId] = func(res interface{}, err error) {
if !isClose {
resCh <- RpcResp{res: res, err: err}
}
r.mutex.Lock()
r.callbacks[msgId] = func(res interface{}, err error) {
if !isClose.Val() {
resCh <- RPCResponse{res: res, err: err}
}
r.mutex.Unlock()
}
r.mutex.Unlock()
// 发送消息
send()
})
r.task(ctx, send)
<-waitCh
select {
case <-time.After(time.Second * consts.TCPRpcTimeout):
err = gerror.New("rpc response timeout")
case <-time.After(time.Second * RPCTimeout):
err = gerror.New("RPC response timeout")
return
case got := <-resCh:
return got.res, got.err
}
}
// Response RPC消息响应
func (r *RPC) Response(ctx context.Context, msg *Message) bool {
if len(msg.MsgId) == 0 {
return false
}
f, ok := r.popCallback(msg.MsgId)
if !ok {
return false
}
var msgError error
if len(msg.Error) > 0 {
msgError = gerror.New(msg.Error)
}
r.task(ctx, func() {
f(msg.Data, msgError)
})
return true
}
// popCallback 弹出回调
func (r *RPC) popCallback(msgId string) (RPCResponseFunc, bool) {
r.mutex.Lock()
defer r.mutex.Unlock()
call, ok := r.callbacks[msgId]
if ok {
delete(r.callbacks, msgId)
}
return call, ok
}

View File

@@ -7,24 +7,30 @@ package tcp
import (
"context"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
"hotgo/internal/consts"
"hotgo/utility/simple"
"reflect"
"sync"
"time"
)
// ClientConn 连接到tcp服务器的客户端对象
type ClientConn struct {
Conn *gtcp.Conn // 连接对象
Auth *AuthMeta // 认证元数据
heartbeat int64 // 心跳
// Server tcp服务器
type Server struct {
ctx context.Context // 上下文
name string // 服务器名称
addr string // 服务器地址
ln *gtcp.Server // tcp服务器
logger *glog.Logger // 日志处理器
wgLn sync.WaitGroup // 流程控制让tcp服务器按可控流程启动退出
closeFlag *gtype.Bool // 服务关闭标签
clients map[string]*Conn // 已登录的认证客户端
mutexConns sync.Mutex // 连接锁,主要用于客户端上下线
taskGo *grpool.Pool // 任务协程池
msgParser *MsgParser // 消息处理器
}
// ServerConfig tcp服务器配置
@@ -33,182 +39,121 @@ type ServerConfig struct {
Addr string // 监听地址
}
// Server tcp服务器对象结构
type Server struct {
Ctx context.Context // 上下文
Logger *glog.Logger // 日志处理器
addr string // 连接地址
name string // 服务器名称
rpc *Rpc // rpc协议
ln *gtcp.Server // tcp服务器
wgLn sync.WaitGroup // 状态控制主要用于tcp服务器能够按流程启动退出
mutex sync.Mutex // 服务器状态锁
closeFlag bool // 服务关闭标签
clients map[string]*ClientConn // 已登录的认证客户端
mutexConns sync.Mutex // 连接锁,主要用于客户端上下线
msgGo *grpool.Pool // 消息处理协程池
cronRouters map[string]RouterHandler // 定时任务路由
queueRouters map[string]RouterHandler // 队列路由
authRouters map[string]RouterHandler // 任务路由
}
// NewServer 初始一个tcp服务器对象
func NewServer(config *ServerConfig) (server *Server, err error) {
func NewServer(config *ServerConfig) (server *Server) {
server = new(Server)
server.ctx = gctx.New()
server.logger = g.Log("tcpServer")
baseErr := gerror.New("TCPServer start fail")
if config == nil {
err = gerror.New("config is nil")
server.logger.Fatal(server.ctx, gerror.Wrap(baseErr, "config is nil"))
return
}
if config.Addr == "" {
err = gerror.New("server address is not set")
server.logger.Fatal(server.ctx, gerror.Wrap(baseErr, "server address is not set"))
return
}
server = new(Server)
server.Ctx = gctx.New()
if config.Name == "" {
config.Name = simple.AppName(server.Ctx)
config.Name = simple.AppName(server.ctx)
}
server.addr = config.Addr
server.name = config.Name
server.ln = gtcp.NewServer(server.addr, server.accept, config.Name)
server.clients = make(map[string]*ClientConn)
server.closeFlag = false
server.Logger = g.Log("tcpServer")
server.msgGo = grpool.New(20)
server.rpc = NewRpc(server.Ctx, server.msgGo, server.Logger)
server.closeFlag = gtype.NewBool(false)
server.clients = make(map[string]*Conn)
server.taskGo = grpool.New(20)
server.msgParser = NewMsgParser(server.handleRoutineTask)
server.startCron()
return
}
// accept
func (server *Server) accept(conn *gtcp.Conn) {
defer func() {
server.mutexConns.Lock()
_ = conn.Close()
// 从登录列表中移除
delete(server.clients, conn.RemoteAddr().String())
server.mutexConns.Unlock()
tcpConn := NewConn(conn, server.logger, server.msgParser)
server.AddClient(tcpConn)
go func() {
if err := tcpConn.Run(); err != nil {
server.logger.Info(server.ctx, err)
}
// cleanup
tcpConn.Close()
server.RemoveClient(tcpConn)
}()
}
for {
msg, err := RecvPkg(conn)
if err != nil {
server.Logger.Debugf(server.Ctx, "RecvPkg err:%+v, client closed.", err)
break
}
client := server.getLoginConn(conn)
switch msg.Router {
case "ServerLogin": // 服务登录
// 初始化上下文
ctx := initCtx(gctx.New(), &Context{
Conn: conn,
})
server.doHandleRouterMsg(ctx, server.onServerLogin, msg.Data)
case "ServerHeartbeat": // 心跳
if client == nil {
server.Logger.Infof(server.Ctx, "conn not connected, ignore the heartbeat, msg:%+v", msg)
continue
}
// 初始化上下文
ctx := initCtx(gctx.New(), &Context{})
server.doHandleRouterMsg(ctx, server.onServerHeartbeat, msg.Data, client)
default: // 通用路由消息处理
if client == nil {
server.Logger.Warningf(server.Ctx, "conn is not logged in but sends a routing message. actively conn disconnect, msg:%+v", msg)
time.Sleep(time.Second)
conn.Close()
return
}
server.handleRouterMsg(msg, client)
}
// RemoveClient 移除客户端
func (server *Server) RemoveClient(conn *Conn) {
label := server.ClientLabel(conn.Conn)
if _, ok := server.clients[label]; ok {
server.mutexConns.Lock()
delete(server.clients, label)
server.mutexConns.Unlock()
}
}
// handleRouterMsg 处理路由消息
func (server *Server) handleRouterMsg(msg *Message, client *ClientConn) {
// 验证签名
in, err := VerifySign(msg.Data, client.Auth.AppId, client.Auth.SecretKey)
if err != nil {
server.Logger.Warningf(server.Ctx, "handleRouterMsg VerifySign err:%+v message: %+v", err, msg)
return
}
// 初始化上下文
ctx := initCtx(gctx.New(), &Context{
Conn: client.Conn,
Auth: client.Auth,
TraceID: in.TraceID,
})
// 响应rpc消息
if server.rpc.HandleMsg(ctx, msg.Data) {
return
}
handle := func(routers map[string]RouterHandler, group string) {
if routers == nil {
server.Logger.Debugf(server.Ctx, "handleRouterMsg route is not initialized %v message: %+v", group, msg)
return
}
f, ok := routers[msg.Router]
if !ok {
server.Logger.Debugf(server.Ctx, "handleRouterMsg invalid %v message: %+v", group, msg)
return
}
server.doHandleRouterMsg(ctx, f, msg.Data)
}
switch client.Auth.Group {
case consts.TCPClientGroupCron:
handle(server.cronRouters, client.Auth.Group)
case consts.TCPClientGroupQueue:
handle(server.queueRouters, client.Auth.Group)
case consts.TCPClientGroupAuth:
handle(server.authRouters, client.Auth.Group)
default:
server.Logger.Warningf(server.Ctx, "group is not registered: %+v", client.Auth.Group)
}
// AddClient 添加客户端
func (server *Server) AddClient(conn *Conn) {
server.mutexConns.Lock()
server.clients[server.ClientLabel(conn.Conn)] = conn
server.mutexConns.Unlock()
}
// doHandleRouterMsg 处理路由消息
func (server *Server) doHandleRouterMsg(ctx context.Context, fun RouterHandler, args ...interface{}) {
ctx, cancel := context.WithCancel(ctx)
err := server.msgGo.AddWithRecover(ctx,
func(ctx context.Context) {
fun(ctx, args...)
cancel()
},
func(ctx context.Context, err error) {
server.Logger.Warningf(ctx, "doHandleRouterMsg msgGo exec err:%+v", err)
cancel()
},
)
if err != nil {
server.Logger.Warningf(ctx, "doHandleRouterMsg msgGo Add err:%+v", err)
// AuthClient 认证客户端
func (server *Server) AuthClient(conn *Conn, auth *AuthMeta) {
label := server.ClientLabel(conn.Conn)
client, ok := server.clients[label]
if !ok {
server.logger.Debugf(server.ctx, "authClient client does not exist:%v", label)
return
}
client.Auth = auth
}
// getLoginConn 获取指定已登录的连接
func (server *Server) getLoginConn(conn *gtcp.Conn) *ClientConn {
client, ok := server.clients[conn.RemoteAddr().String()]
// ClientLabel 客户端标识
func (server *Server) ClientLabel(conn *gtcp.Conn) string {
return conn.RemoteAddr().String()
}
// GetClient 获取指定连接
func (server *Server) GetClient(conn *gtcp.Conn) *Conn {
client, ok := server.clients[server.ClientLabel(conn)]
if !ok {
return nil
}
return client
}
// getLoginConn 获取指定appid的所有连接
func (server *Server) getAppIdClients(appid string) (list []*ClientConn) {
// GetClients 获取所有连接
func (server *Server) GetClients() map[string]*Conn {
return server.clients
}
// GetClientById 通过连接ID获取连接
func (server *Server) GetClientById(id int64) *Conn {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
for _, v := range server.clients {
if v.Auth.AppId == appid {
if v.CID == id {
return v
}
}
return nil
}
// GetAppIdClients 获取指定appid的所有连接
func (server *Server) GetAppIdClients(appid string) (list []*Conn) {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
for _, v := range server.clients {
if v.Auth != nil && v.Auth.AppId == appid {
list = append(list, v)
}
}
@@ -216,70 +161,63 @@ func (server *Server) getAppIdClients(appid string) (list []*ClientConn) {
}
// GetGroupClients 获取指定分组的所有连接
func (server *Server) GetGroupClients(group string) (list []*ClientConn) {
func (server *Server) GetGroupClients(group string) (list []*Conn) {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
for _, v := range server.clients {
if v.Auth.Group == group {
if v.Auth != nil && v.Auth.Group == group {
list = append(list, v)
}
}
return
}
// RegisterAuthRouter 注册授权路由
func (server *Server) RegisterAuthRouter(routers map[string]RouterHandler) {
server.mutex.Lock()
defer server.mutex.Unlock()
if server.authRouters == nil {
server.authRouters = make(map[string]RouterHandler)
}
for i, router := range routers {
_, ok := server.authRouters[i]
if ok {
server.Logger.Debugf(server.Ctx, "server authRouters duplicate registration:%v", i)
continue
}
server.authRouters[i] = router
}
// GetAppIdOnline 获取指定appid的在线数量
func (server *Server) GetAppIdOnline(appid string) int {
return len(server.GetAppIdClients(appid))
}
// RegisterCronRouter 注册任务路由
func (server *Server) RegisterCronRouter(routers map[string]RouterHandler) {
server.mutex.Lock()
defer server.mutex.Unlock()
if server.cronRouters == nil {
server.cronRouters = make(map[string]RouterHandler)
}
for i, router := range routers {
_, ok := server.cronRouters[i]
if ok {
server.Logger.Debugf(server.Ctx, "server cronRouters duplicate registration:%v", i)
continue
}
server.cronRouters[i] = router
}
// GetAllOnline 获取所有在线数量
func (server *Server) GetAllOnline() int {
return len(server.clients)
}
// RegisterQueueRouter 注册队列路由
func (server *Server) RegisterQueueRouter(routers map[string]RouterHandler) {
server.mutex.Lock()
defer server.mutex.Unlock()
// GetAuthOnline 获取所有已登录认证在线数量
func (server *Server) GetAuthOnline() int {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
if server.queueRouters == nil {
server.queueRouters = make(map[string]RouterHandler)
}
for i, router := range routers {
_, ok := server.queueRouters[i]
if ok {
server.Logger.Debugf(server.Ctx, "server queueRouters duplicate registration:%v", i)
continue
online := 0
for _, v := range server.clients {
if v.Auth != nil {
online++
}
server.queueRouters[i] = router
}
return online
}
// RegisterRouter 注册路由
func (server *Server) RegisterRouter(routers ...interface{}) {
err := server.msgParser.RegisterRouter(routers...)
if err != nil {
server.logger.Fatal(server.ctx, err)
}
return
}
// RegisterRPCRouter 注册RPC路由
func (server *Server) RegisterRPCRouter(routers ...interface{}) {
err := server.msgParser.RegisterRPCRouter(routers...)
if err != nil {
server.logger.Fatal(server.ctx, err)
}
return
}
// RegisterInterceptor 注册拦截器
func (server *Server) RegisterInterceptor(interceptors ...Interceptor) {
server.msgParser.RegisterInterceptor(interceptors...)
}
// Listen 监听服务
@@ -291,77 +229,68 @@ func (server *Server) Listen() (err error) {
// Close 关闭服务
func (server *Server) Close() {
if server.closeFlag {
if server.closeFlag.Val() {
return
}
server.closeFlag = true
server.closeFlag.Set(true)
server.stopCron()
server.mutexConns.Lock()
for _, client := range server.clients {
_ = client.Conn.Close()
client.Conn.Close()
}
server.clients = nil
server.mutexConns.Unlock()
if server.ln != nil {
_ = server.ln.Close()
server.ln.Close()
}
server.wgLn.Wait()
}
// IsClose 服务是否关闭
func (server *Server) IsClose() bool {
return server.closeFlag
return server.closeFlag.Val()
}
// Write 向指定客户端发送消息
func (server *Server) Write(conn *gtcp.Conn, data interface{}) (err error) {
if server.closeFlag {
return gerror.New("service is down")
}
msgType := reflect.TypeOf(data)
if msgType == nil || msgType.Kind() != reflect.Ptr {
return gerror.Newf("json message pointer required: %+v", data)
}
msg := &Message{Router: msgType.Elem().Name(), Data: data}
return SendPkg(conn, msg)
}
// Send 发送消息
func (server *Server) Send(ctx context.Context, client *ClientConn, data interface{}) (err error) {
MsgPkg(data, client.Auth, gctx.CtxId(ctx))
return server.Write(client.Conn, data)
}
// Reply 回复消息
func (server *Server) Reply(ctx context.Context, data interface{}) (err error) {
user := GetCtx(ctx)
if user == nil {
err = gerror.New("获取回复用户信息失败")
return
}
MsgPkg(data, user.Auth, user.TraceID)
return server.Write(user.Conn, data)
}
// RpcRequest 向指定客户端发送消息并等待响应结果
func (server *Server) RpcRequest(ctx context.Context, client *ClientConn, data interface{}) (res interface{}, err error) {
var (
traceID = MsgPkg(data, client.Auth, gctx.CtxId(ctx))
key = server.rpc.GetCallId(client.Conn, traceID)
// handleRoutineTask 处理协程任务
func (server *Server) handleRoutineTask(ctx context.Context, task func()) {
ctx, cancel := context.WithCancel(ctx)
err := server.taskGo.AddWithRecover(ctx,
func(ctx context.Context) {
task()
cancel()
},
func(ctx context.Context, err error) {
server.logger.Warningf(ctx, "routineTask exec err:%+v", err)
cancel()
},
)
if traceID == "" {
err = gerror.New("traceID is required")
if err != nil {
server.logger.Warningf(ctx, "routineTask add err:%+v", err)
}
}
// GetRoutes 获取所有路由
func (server *Server) GetRoutes() (routes []RouteHandler) {
if server.msgParser.routers == nil {
return
}
return server.rpc.Request(key, func() {
_ = server.Write(client.Conn, data)
})
for _, v := range server.msgParser.routers {
routes = append(routes, *v)
}
return
}
// Request 向指定客户端发送消息并等待响应结果
func (server *Server) Request(ctx context.Context, client *Conn, data interface{}) (interface{}, error) {
return client.Request(ctx, data)
}
// RequestScan 向指定客户端发送消息并等待响应结果将结果保存在response中
func (server *Server) RequestScan(ctx context.Context, client *Conn, data, response interface{}) error {
return client.RequestScan(ctx, data, response)
}

View File

@@ -10,7 +10,6 @@ import (
"fmt"
"github.com/gogf/gf/v2/os/gcron"
"github.com/gogf/gf/v2/os/gtime"
"hotgo/internal/consts"
)
// getCronKey 生成服务端定时任务名称
@@ -28,32 +27,35 @@ func (server *Server) stopCron() {
// startCron 启动定时任务
func (server *Server) startCron() {
// 心跳超时检查
if gcron.Search(server.getCronKey(consts.TCPCronHeartbeatVerify)) == nil {
_, _ = gcron.AddSingleton(server.Ctx, "@every 300s", func(ctx context.Context) {
if server.clients == nil {
if gcron.Search(server.getCronKey(CronHeartbeatVerify)) == nil {
gcron.AddSingleton(server.ctx, "@every 300s", func(ctx context.Context) {
if server == nil || server.clients == nil {
return
}
for _, client := range server.clients {
if client.heartbeat < gtime.Timestamp()-consts.TCPHeartbeatTimeout {
_ = client.Conn.Close()
server.Logger.Debugf(server.Ctx, "client heartbeat timeout, close conn. auth:%+v", client.Auth)
if client.Heartbeat < gtime.Timestamp()-HeartbeatTimeout {
client.Conn.Close()
server.logger.Debugf(server.ctx, "client heartbeat timeout, close conn. auth:%+v", client.Auth)
}
}
}, server.getCronKey(consts.TCPCronHeartbeatVerify))
}, server.getCronKey(CronHeartbeatVerify))
}
// 认证检查
if gcron.Search(server.getCronKey(consts.TCPCronAuthVerify)) == nil {
_, _ = gcron.AddSingleton(server.Ctx, "@every 300s", func(ctx context.Context) {
if server.clients == nil {
if gcron.Search(server.getCronKey(CronAuthVerify)) == nil {
gcron.AddSingleton(server.ctx, "@every 300s", func(ctx context.Context) {
if server == nil || server.clients == nil {
return
}
for _, client := range server.clients {
if client.Auth == nil {
continue
}
if client.Auth.EndAt.Before(gtime.Now()) {
_ = client.Conn.Close()
server.Logger.Debugf(server.Ctx, "client auth expired, close conn. auth:%+v", client.Auth)
client.Conn.Close()
server.logger.Debugf(server.ctx, "client auth expired, close conn. auth:%+v", client.Auth)
}
}
}, server.getCronKey(consts.TCPCronAuthVerify))
}, server.getCronKey(CronAuthVerify))
}
}

View File

@@ -1,166 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"context"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/consts"
"hotgo/internal/model/entity"
"hotgo/internal/model/input/msgin"
"hotgo/utility/convert"
)
// onServerLogin 处理客户端登录
func (server *Server) onServerLogin(ctx context.Context, args ...interface{}) {
var (
in = new(msgin.ServerLogin)
user = GetCtx(ctx)
res = new(msgin.ResponseServerLogin)
models *entity.SysServeLicense
)
if err := gconv.Scan(args[0], &in); err != nil {
server.Logger.Warningf(ctx, "onServerLogin message Scan failed:%+v, args:%+v", err, args)
return
}
err := g.Model("sys_serve_license").Ctx(ctx).
Where("appid = ?", in.AppId).
Scan(&models)
if err != nil {
res.Code = 1
res.Message = err.Error()
_ = server.Write(user.Conn, res)
return
}
if models == nil {
res.Code = 2
res.Message = "授权信息不存在"
_ = server.Write(user.Conn, res)
return
}
// 验证签名
if _, err = VerifySign(in, models.Appid, models.SecretKey); err != nil {
res.Code = 3
res.Message = "签名错误,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
if models.Status != consts.StatusEnabled {
res.Code = 4
res.Message = "授权已禁用,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
if models.Group != in.Group {
res.Code = 5
res.Message = "你登录的授权分组未得到授权,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
if models.EndAt.Before(gtime.Now()) {
res.Code = 6
res.Message = "授权已过期,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
allowedIps := convert.IpFilterStrategy(models.AllowedIps)
if _, ok := allowedIps["*"]; !ok {
ip := gstr.StrTillEx(user.Conn.RemoteAddr().String(), ":")
if _, ok2 := allowedIps[ip]; !ok2 {
res.Code = 7
res.Message = "IP(" + ip + ")未授权,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
}
// 检查是否存在多地登录,如果连接超出上限,直接将所有已连接断开
clients := server.getAppIdClients(models.Appid)
online := len(clients) + 1
if online > models.OnlineLimit {
res2 := new(msgin.ResponseServerLogin)
res2.Code = 8
res2.Message = "授权登录端超出上限已进行记录。请立即终止操作。如有疑问请联系管理员"
for _, client := range clients {
_ = server.Write(client.Conn, res2)
_ = client.Conn.Close()
}
// 当前连接也踢掉
_ = server.Write(user.Conn, res2)
_ = user.Conn.Close()
return
}
server.mutexConns.Lock()
server.clients[user.Conn.RemoteAddr().String()] = &ClientConn{
Conn: user.Conn,
Auth: &AuthMeta{
Group: in.Group,
Name: in.Name,
AppId: in.AppId,
SecretKey: models.SecretKey,
EndAt: models.EndAt,
},
heartbeat: gtime.Timestamp(),
}
server.mutexConns.Unlock()
_, err = g.Model("sys_serve_license").Ctx(ctx).
Where("id = ?", models.Id).Data(g.Map{
"online": online,
"login_times": models.LoginTimes + 1,
"last_login_at": gtime.Now(),
"last_active_at": gtime.Now(),
"remote_addr": user.Conn.RemoteAddr().String(),
}).Update()
if err != nil {
server.Logger.Warningf(ctx, "onServerLogin Update err:%+v", err)
}
res.AppId = in.AppId
res.Code = consts.TCPMsgCodeSuccess
_ = server.Write(user.Conn, res)
}
// onServerHeartbeat 处理客户端心跳
func (server *Server) onServerHeartbeat(ctx context.Context, args ...interface{}) {
var (
in *msgin.ServerHeartbeat
res = new(msgin.ResponseServerHeartbeat)
)
if err := gconv.Scan(args[0], &in); err != nil {
server.Logger.Warningf(ctx, "onServerHeartbeat message Scan failed:%+v, args:%+v", err, args)
return
}
client := args[1].(*ClientConn)
client.heartbeat = gtime.Timestamp()
_, err := g.Model("sys_serve_license").Ctx(ctx).
Where("appid = ?", client.Auth.AppId).Data(g.Map{
"last_active_at": gtime.Now(),
}).Update()
if err != nil {
server.Logger.Warningf(ctx, "onServerHeartbeat Update err:%+v", err)
}
res.Code = consts.TCPMsgCodeSuccess
_ = server.Write(client.Conn, res)
}

View File

@@ -1,49 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/model/input/msgin"
)
type Sign interface {
SetSign(appId, secretKey string) *msgin.RpcMsg
SetTraceID(traceID string)
}
// PkgSign 打包签名
func PkgSign(data interface{}, appId, secretKey, traceID string) *msgin.RpcMsg {
if c, ok := data.(Sign); ok {
c.SetTraceID(traceID)
return c.SetSign(appId, secretKey)
}
return nil
}
// VerifySign 验证签名
func VerifySign(data interface{}, appId, secretKey string) (in *msgin.RpcMsg, err error) {
// 无密钥,无需签名
if secretKey == "" {
return
}
if err = gconv.Scan(data, &in); err != nil {
return
}
if appId != in.AppId {
err = gerror.New("appId invalid")
return
}
if in.Sign != in.GetSign(secretKey) {
err = gerror.New("sign invalid")
return
}
return
}

View File

@@ -0,0 +1,128 @@
package tcp_test
import (
"context"
"fmt"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/test/gtest"
"hotgo/internal/library/network/tcp"
"testing"
"time"
)
var T *testing.T // 声明一个全局的 *testing.T 变量
type TestMsgReq struct {
Name string `json:"name"`
}
type TestMsgRes struct {
tcp.ServerRes
}
type TestRPCMsgReq struct {
Name string `json:"name"`
}
type TestRPCMsgRes struct {
tcp.ServerRes
}
func onTestMsg(ctx context.Context, req *TestMsgReq) {
fmt.Printf("服务器收到消息 ==> onTestMsg:%+v\n", req)
conn := tcp.ConnFromCtx(ctx)
gtest.C(T, func(t *gtest.T) {
t.AssertNE(conn, nil)
})
res := new(TestMsgRes)
res.Message = fmt.Sprintf("你的名字:%v", req.Name)
conn.Send(ctx, res)
}
func onResponseTestMsg(ctx context.Context, req *TestMsgRes) {
fmt.Printf("客户端收到响应消息 ==> TestMsgRes:%+v\n", req)
err := req.GetError()
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}
func onTestRPCMsg(ctx context.Context, req *TestRPCMsgReq) (res *TestRPCMsgRes, err error) {
fmt.Printf("服务器收到消息 ==> onTestRPCMsg:%+v\n", req)
res = new(TestRPCMsgRes)
res.Message = fmt.Sprintf("你的名字:%v", req.Name)
return
}
func startTCPServer() {
serv := tcp.NewServer(&tcp.ServerConfig{
Name: "hotgo",
Addr: ":8002",
})
// 注册路由
serv.RegisterRouter(
onTestMsg,
)
// 注册RPC路由
serv.RegisterRPCRouter(
onTestRPCMsg,
)
// 服务监听
err := serv.Listen()
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}
// 一个基本的消息收发
func TestSendMsg(t *testing.T) {
T = t
go startTCPServer()
ctx := gctx.New()
client := tcp.NewClient(&tcp.ClientConfig{
Addr: "127.0.0.1:8002",
})
// 注册路由
client.RegisterRouter(
onResponseTestMsg,
)
go func() {
err := client.Start()
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}()
// 确保服务都启动完成
time.Sleep(time.Second * 1)
// 拿到客户端的连接
conn := client.Conn()
gtest.C(T, func(t *gtest.T) {
t.AssertNE(conn, nil)
})
// 向服务器发送tcp消息不会阻塞程序执行
err := conn.Send(ctx, &TestMsgReq{Name: "Tom"})
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
// 向服务器发送rpc消息会等待服务器响应结果直到拿到结果或响应超时才会继续
var res TestRPCMsgRes
if err = conn.RequestScan(ctx, &TestRPCMsgReq{Name: "Tony"}, &res); err != nil {
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}
fmt.Printf("客户端收到RPC消息响应 ==> TestRPCMsgRes:%+v\n", res)
time.Sleep(time.Second * 1)
}

View File

@@ -12,7 +12,7 @@ import (
// 异步回调
type NotifyCallFunc func(ctx context.Context, pay payin.NotifyCallFuncInp) (err error)
type NotifyCallFunc func(ctx context.Context, pay *payin.NotifyCallFuncInp) (err error)
var (
notifyCall = make(map[string]NotifyCallFunc)
@@ -37,7 +37,7 @@ func RegisterNotifyCallMap(calls map[string]NotifyCallFunc) {
}
// NotifyCall 执行订单分组的异步回调
func NotifyCall(ctx context.Context, in payin.NotifyCallFuncInp) {
func NotifyCall(ctx context.Context, in *payin.NotifyCallFuncInp) {
f, ok := notifyCall[in.Pay.OrderGroup]
if ok {
ctx = contexts.Detach(ctx)

View File

@@ -1,120 +0,0 @@
package aliyun
import (
"context"
"fmt"
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
dysmsapi20170525 "github.com/alibabacloud-go/dysmsapi-20170525/v3/client"
util "github.com/alibabacloud-go/tea-utils/v2/service"
"github.com/alibabacloud-go/tea/tea"
"github.com/gogf/gf/v2/frame/g"
"hotgo/internal/model"
"hotgo/internal/model/input/sysin"
"hotgo/internal/service"
)
var (
Handle = aliYun{}
)
type aliYun struct{}
// SendCode 发送验证码
func (d *aliYun) SendCode(ctx context.Context, in sysin.SendCodeInp, config *model.SmsConfig) (err error) {
if config == nil {
config, err = service.SysConfig().GetSms(ctx)
if err != nil {
return err
}
}
client, err := CreateClient(tea.String(config.AliYunAccessKeyID), tea.String(config.AliYunAccessKeySecret))
if err != nil {
return err
}
sendSmsRequest := &dysmsapi20170525.SendSmsRequest{
PhoneNumbers: tea.String(in.Mobile),
SignName: tea.String(config.AliYunSign),
TemplateCode: tea.String(in.Template),
TemplateParam: tea.String(fmt.Sprintf("{\"code\":\"%v\"}", in.Code)),
}
tryErr := func() (_e error) {
defer func() {
if r := tea.Recover(recover()); r != nil {
_e = r
}
}()
// 复制代码运行请自行打印 API 的返回值
response, err := client.SendSmsWithOptions(sendSmsRequest, &util.RuntimeOptions{})
if err != nil {
return err
}
g.Log().Debugf(ctx, "aliyun.sendCode response:%+v", response.GoString())
return nil
}()
return tryErr
}
// CreateClient 使用AK&SK初始化账号Client
func CreateClient(accessKeyId *string, accessKeySecret *string) (_result *dysmsapi20170525.Client, _err error) {
config := &openapi.Config{
// 必填,您的 AccessKey ID
AccessKeyId: accessKeyId,
// 必填,您的 AccessKey Secret
AccessKeySecret: accessKeySecret,
}
// 访问的域名
config.Endpoint = tea.String("dysmsapi.aliyuncs.com")
_result, _err = dysmsapi20170525.NewClient(config)
return _result, _err
}
func TestSend(accessKeyId string, accessKeySecret string) (_err error) {
// 工程代码泄露可能会导致AccessKey泄露并威胁账号下所有资源的安全性。以下代码示例仅供参考建议使用更安全的 STS 方式更多鉴权访问方式请参见https://help.aliyun.com/document_detail/378661.html
client, _err := CreateClient(tea.String(accessKeyId), tea.String(accessKeySecret))
if _err != nil {
return _err
}
sendSmsRequest := &dysmsapi20170525.SendSmsRequest{
PhoneNumbers: tea.String("15303830571"),
SignName: tea.String("布帆云"),
TemplateCode: tea.String("SMS_198921686"),
TemplateParam: tea.String("{\"code\":\"1234\"}"),
}
runtime := &util.RuntimeOptions{}
tryErr := func() (_e error) {
defer func() {
if r := tea.Recover(recover()); r != nil {
_e = r
}
}()
// 复制代码运行请自行打印 API 的返回值
_, _err = client.SendSmsWithOptions(sendSmsRequest, runtime)
if _err != nil {
return _err
}
return nil
}()
if tryErr != nil {
var err = &tea.SDKError{}
if _t, ok := tryErr.(*tea.SDKError); ok {
err = _t
} else {
err.Message = tea.String(tryErr.Error())
}
// 如有需要,请打印 error
_, _err = util.AssertAsString(err.Message)
if _err != nil {
return _err
}
}
return _err
}

View File

@@ -0,0 +1,22 @@
package sms
import (
"context"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"hotgo/internal/model"
)
var config *model.SmsConfig
func SetConfig(c *model.SmsConfig) {
config = c
}
func GetConfig() *model.SmsConfig {
return config
}
func GetModel(ctx context.Context) *gdb.Model {
return g.Model("sys_sms_log").Ctx(ctx)
}

View File

@@ -4,15 +4,12 @@ import (
"context"
"fmt"
"hotgo/internal/consts"
"hotgo/internal/library/sms/aliyun"
"hotgo/internal/library/sms/tencent"
"hotgo/internal/model"
"hotgo/internal/model/input/sysin"
)
// Drive 短信驱动
type Drive interface {
SendCode(ctx context.Context, in sysin.SendCodeInp, config *model.SmsConfig) (err error)
SendCode(ctx context.Context, in *sysin.SendCodeInp) (err error)
}
func New(name ...string) Drive {
@@ -27,9 +24,9 @@ func New(name ...string) Drive {
switch instanceName {
case consts.SmsDriveAliYun:
drive = &aliyun.Handle
drive = &AliYunDrive{}
case consts.SmsDriveTencent:
drive = &tencent.Handle
drive = &TencentDrive{}
default:
panic(fmt.Sprintf("暂不支持短信驱动:%v", instanceName))
}

View File

@@ -0,0 +1,60 @@
package sms
import (
"context"
"fmt"
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
dysmsapi20170525 "github.com/alibabacloud-go/dysmsapi-20170525/v3/client"
util "github.com/alibabacloud-go/tea-utils/v2/service"
"github.com/alibabacloud-go/tea/tea"
"github.com/gogf/gf/v2/frame/g"
"hotgo/internal/model/input/sysin"
)
type AliYunDrive struct{}
// SendCode 发送验证码
func (d *AliYunDrive) SendCode(ctx context.Context, in *sysin.SendCodeInp) (err error) {
client, err := d.CreateClient(tea.String(config.AliYunAccessKeyID), tea.String(config.AliYunAccessKeySecret))
if err != nil {
return err
}
sendSmsRequest := &dysmsapi20170525.SendSmsRequest{
PhoneNumbers: tea.String(in.Mobile),
SignName: tea.String(config.AliYunSign),
TemplateCode: tea.String(in.Template),
TemplateParam: tea.String(fmt.Sprintf("{\"code\":\"%v\"}", in.Code)),
}
return func() (_e error) {
defer func() {
if r := tea.Recover(recover()); r != nil {
_e = r
}
}()
// 复制代码运行请自行打印 API 的返回值
response, err := client.SendSmsWithOptions(sendSmsRequest, &util.RuntimeOptions{})
if err != nil {
return err
}
g.Log().Debugf(ctx, "aliyun.sendCode response:%+v", response.GoString())
return nil
}()
}
// CreateClient 使用AK&SK初始化账号Client
func (d *AliYunDrive) CreateClient(accessKeyId *string, accessKeySecret *string) (result *dysmsapi20170525.Client, err error) {
conf := &openapi.Config{
// 必填,您的 AccessKey ID
AccessKeyId: accessKeyId,
// 必填,您的 AccessKey Secret
AccessKeySecret: accessKeySecret,
}
// 访问的域名
conf.Endpoint = tea.String("dysmsapi.aliyuncs.com")
return dysmsapi20170525.NewClient(conf)
}

View File

@@ -1,11 +1,10 @@
package tencent
package sms
import (
"context"
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/frame/g"
"hotgo/internal/model"
"hotgo/internal/model/input/sysin"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
@@ -14,14 +13,10 @@ import (
sms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111" // 引入sms
)
var (
Handle = tencent{}
)
type tencent struct{}
type TencentDrive struct{}
// SendCode 发送验证码
func (d *tencent) SendCode(ctx context.Context, in sysin.SendCodeInp, config *model.SmsConfig) (err error) {
func (d *TencentDrive) SendCode(ctx context.Context, in *sysin.SendCodeInp) (err error) {
/* 必要步骤
* 实例化一个认证对象入参需要传入腾讯云账户密钥对secretIdsecretKey
* 这里采用的是从环境变量读取的方式需要在环境变量中先设置这两个值

View File

@@ -18,13 +18,16 @@ import (
// 文件归属分类
const (
KindImg = "images" // 图片
KindDoc = "document" // 文档
KindAudio = "audio" // 音频
KindVideo = "video" // 视频
KindOther = "other" // 其他
KindImg = "image" // 图片
KindDoc = "doc" // 文档
KindAudio = "audio" // 音频
KindVideo = "video" // 视频
KindZip = "zip" // 压缩包
KindOther = "other" // 其他
)
var KindSlice = []string{KindImg, KindDoc, KindAudio, KindVideo, KindZip, KindOther}
var (
// 图片类型
imgType = g.MapStrStr{
@@ -45,12 +48,33 @@ var (
// 文档类型
docType = g.MapStrStr{
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"xls": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"dot": "application/msword",
"xls": "application/vnd.ms-excel",
"xlt": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"pdf": "application/pdf",
"txt": "text/plain",
"csv": "text/csv",
"html": "text/html",
"xml": "text/xml",
"pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12",
"pot": "application/vnd.ms-powerpoint",
"wpd": "application/wordperfect",
"md": "text/markdown",
"json": "application/json",
"yaml": "application/x-yaml",
"markdown": "text/markdown",
"asciidoc": "text/asciidoc",
"xsl": "application/xml",
"wps": "application/vnd.ms-works",
"sxi": "application/vnd.sun.xml.impress",
"sti": "application/vnd.sun.xml.impress.template",
"odp": "application/vnd.oasis.opendocument.presentation-template",
}
// 音频类型
@@ -79,6 +103,23 @@ var (
"flv": "video/x-flv",
"3gp": "video/3gpp",
}
// 压缩包
zipType = g.MapStrStr{
"zip": "application/zip",
"rar": "application/x-rar-compressed",
"tar": "application/x-tar",
"gz": "application/gzip",
"7z": "application/octet-stream",
"tar.gz": "application/octet-stream",
}
// 其他
otherType = g.MapStrStr{
"dwf": "model/vnd.dwf",
"ics": "text/calendar",
"vcard": "text/vcard",
}
)
// IsImgType 判断是否为图片
@@ -105,16 +146,17 @@ func IsVideoType(ext string) bool {
return ok
}
// GetImgType 获取图片类型
func GetImgType(ext string) (string, error) {
if mime, ok := imgType[ext]; ok {
return mime, nil
}
return "", gerror.New("Invalid image type")
// IsZipType 判断是否为压缩包
func IsZipType(ext string) bool {
_, ok := zipType[ext]
return ok
}
// GetFileType 获取文件类型
func GetFileType(ext string) (string, error) {
// GetFileMimeType 获取文件扩展类型
// 如果文件类型没有加入系统映射类型,默认认为不是合法的文件类型。建议将常用的上传文件类型加入映射关系。
// 当然你也可以不做限制,可以上传任意文件。但需要谨慎处理和设置相应的安全措施。
// 获取任意扩展名的扩展类型mime.TypeByExtension(".xls")
func GetFileMimeType(ext string) (string, error) {
if mime, ok := imgType[ext]; ok {
return mime, nil
}
@@ -127,10 +169,16 @@ func GetFileType(ext string) (string, error) {
if mime, ok := videoType[ext]; ok {
return mime, nil
}
if mime, ok := zipType[ext]; ok {
return mime, nil
}
if mime, ok := otherType[ext]; ok {
return mime, nil
}
return "", gerror.Newf("Invalid file type:%v", ext)
}
// GetFileKind 获取文件所属分类
// GetFileKind 获取文件上传类型
func GetFileKind(ext string) string {
if _, ok := imgType[ext]; ok {
return KindImg
@@ -144,12 +192,15 @@ func GetFileKind(ext string) string {
if _, ok := videoType[ext]; ok {
return KindVideo
}
if _, ok := zipType[ext]; ok {
return KindZip
}
return KindOther
}
// Ext 获取文件后缀
func Ext(baseName string) string {
return gstr.StrEx(path.Ext(baseName), ".")
return gstr.ToLower(gstr.StrEx(path.Ext(baseName), "."))
}
// UploadFileByte 获取上传文件的byte

View File

@@ -4,9 +4,9 @@ package storager
type FileMeta struct {
Filename string // 文件名称
Size int64 // 文件大小
Kind string // 文件所属分类
MetaType string // 文件类型
Kind string // 文件上传类型
MimeType string // 文件扩展类型
NaiveType string // NaiveUI类型
Ext string // 文件后缀
Ext string // 文件扩展
Md5 string // 文件hash
}

View File

@@ -52,7 +52,7 @@ func New(name ...string) UploadDrive {
}
// DoUpload 上传入口
func DoUpload(ctx context.Context, file *ghttp.UploadFile, typ int) (result *entity.SysAttachment, err error) {
func DoUpload(ctx context.Context, typ string, file *ghttp.UploadFile) (result *entity.SysAttachment, err error) {
if file == nil {
err = gerror.New("文件必须!")
return
@@ -63,17 +63,12 @@ func DoUpload(ctx context.Context, file *ghttp.UploadFile, typ int) (result *ent
return
}
if _, err = GetFileType(meta.Ext); err != nil {
if _, err = GetFileMimeType(meta.Ext); err != nil {
return
}
switch typ {
case consts.UploadTypeFile:
if config.FileSize > 0 && meta.Size > config.FileSize*1024*1024 {
err = gerror.Newf("文件大小不能超过%vMB", config.FileSize)
return
}
case consts.UploadTypeImage:
case KindImg:
if !IsImgType(meta.Ext) {
err = gerror.New("上传的文件不是图片")
return
@@ -82,24 +77,34 @@ func DoUpload(ctx context.Context, file *ghttp.UploadFile, typ int) (result *ent
err = gerror.Newf("图片大小不能超过%vMB", config.ImageSize)
return
}
case consts.UploadTypeDoc:
case KindDoc:
if !IsDocType(meta.Ext) {
err = gerror.New("上传的文件不是文档")
return
}
case consts.UploadTypeAudio:
case KindAudio:
if !IsAudioType(meta.Ext) {
err = gerror.New("上传的文件不是音频")
return
}
case consts.UploadTypeVideo:
case KindVideo:
if !IsVideoType(meta.Ext) {
err = gerror.New("上传的文件不是视频")
return
}
case KindZip:
if !IsZipType(meta.Ext) {
err = gerror.New("上传的文件不是压缩文件")
return
}
case KindOther:
fallthrough
default:
err = gerror.Newf("无效的上传类型:%v", typ)
return
// 默认为通用的文件上传
if config.FileSize > 0 && meta.Size > config.FileSize*1024*1024 {
err = gerror.Newf("文件大小不能超过%vMB", config.FileSize)
return
}
}
result, err = hasFile(ctx, meta.Md5)
@@ -149,7 +154,7 @@ func GetFileMeta(file *ghttp.UploadFile) (meta *FileMeta, err error) {
meta.Size = file.Size
meta.Ext = Ext(file.Filename)
meta.Kind = GetFileKind(meta.Ext)
meta.MetaType, err = GetFileType(meta.Ext)
meta.MimeType, err = GetFileMimeType(meta.Ext)
if err != nil {
return
}
@@ -185,7 +190,7 @@ func write(ctx context.Context, meta *FileMeta, fullPath string) (models *entity
FileUrl: fullPath,
Name: meta.Filename,
Kind: meta.Kind,
MetaType: meta.MetaType,
MimeType: meta.MimeType,
NaiveType: meta.NaiveType,
Ext: meta.Ext,
Md5: meta.Md5,

View File

@@ -36,7 +36,7 @@ func (d *CosDrive) Upload(ctx context.Context, file *ghttp.UploadFile) (fullPath
},
})
fullPath = GenFullPath(config.UCloudPath, gfile.Ext(file.Filename))
fullPath = GenFullPath(config.CosPath, gfile.Ext(file.Filename))
_, err = client.Object.Put(ctx, fullPath, f2, nil)
return
}

View File

@@ -36,7 +36,7 @@ func (d *OssDrive) Upload(ctx context.Context, file *ghttp.UploadFile) (fullPath
return
}
fullPath = GenFullPath(config.UCloudPath, gfile.Ext(file.Filename))
fullPath = GenFullPath(config.OssPath, gfile.Ext(file.Filename))
err = bucket.PutObject(fullPath, f2)
return
}

Some files were not shown because too many files have changed in this diff Show More