This commit is contained in:
孟帅
2023-07-20 18:01:10 +08:00
parent 9113fc5297
commit 373d9627fb
492 changed files with 12170 additions and 6982 deletions

View File

@@ -7,19 +7,33 @@ package tcp
import (
"context"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/os/gtime"
"hotgo/utility/simple"
"reflect"
"sync"
"time"
)
// Client tcp客户端
type Client struct {
ctx context.Context // 上下文
conn *Conn // 连接对象
config *ClientConfig // 配置
msgParser *MsgParser // 消息处理器
logger *glog.Logger // 日志处理器
isLogin *gtype.Bool // 是否已登录
taskGo *grpool.Pool // 任务协程池
closeFlag *gtype.Bool // 关闭标签,关闭以后可以重连
stopFlag *gtype.Bool // 停止标签,停止以后不能重连
wg sync.WaitGroup // 流程控制
mutex sync.Mutex // 状态锁
}
// ClientConfig 客户端配置
type ClientConfig struct {
Addr string // 连接地址
@@ -33,86 +47,56 @@ type ClientConfig struct {
CloseEvent CallbackEvent // 连接关闭事件
}
// Client 客户端
type Client struct {
Ctx context.Context // 上下文
Logger *glog.Logger // 日志处理器
IsLogin bool // 是否已登录
addr string // 连接地址
auth *AuthMeta // 认证元数据
rpc *Rpc // rpc协议支持
timeout time.Duration // 连接超时时间
connectInterval time.Duration // 重连时间间隔
maxConnectCount uint // 最大重连次数0不限次数
connectCount uint // 已重连次数
autoReconnect bool // 是否开启自动重连
loginEvent CallbackEvent // 登录成功事件
closeEvent CallbackEvent // 连接关闭事件
sync.Mutex // 状态锁
heartbeat int64 // 心跳
msgGo *grpool.Pool // 消息处理协程池
routers map[string]RouterHandler // 已注册的路由
conn *gtcp.Conn // 连接对象
wg sync.WaitGroup // 状态控制
closeFlag bool // 关闭标签,关闭以后可以重连
stopFlag bool // 停止标签,停止以后不能重连
}
// CallbackEvent 回调事件
type CallbackEvent func()
// NewClient 初始化一个tcp客户端
func NewClient(config *ClientConfig) (client *Client, err error) {
func NewClient(config *ClientConfig) (client *Client) {
client = new(Client)
client.ctx = gctx.New()
client.logger = g.Log("tcpClient")
baseErr := gerror.New("NewClient fail")
if config == nil {
err = gerror.New("config is nil")
client.logger.Fatal(client.ctx, gerror.Wrap(baseErr, "config is nil"))
return
}
if config.Addr == "" {
err = gerror.New("client address is not set")
client.logger.Fatal(client.ctx, gerror.Wrap(baseErr, "client address is not set"))
return
}
if config.Auth == nil {
err = gerror.New("client auth cannot be empty")
return
}
if config.Auth.Group == "" || config.Auth.Name == "" {
err = gerror.New("Auth.Group or Auth.Group is nil")
return
}
client.Ctx = gctx.New()
client.autoReconnect = true
client.addr = config.Addr
client.auth = config.Auth
client.loginEvent = config.LoginEvent
client.closeEvent = config.CloseEvent
client.Logger = g.Log("tcpClient")
client.config = config
if config.ConnectInterval <= 0 {
client.connectInterval = 5 * time.Second
client.config.ConnectInterval = 5 * time.Second
} else {
client.connectInterval = config.ConnectInterval
client.config.ConnectInterval = config.ConnectInterval
}
if config.Timeout <= 0 {
client.timeout = 10 * time.Second
client.config.Timeout = 10 * time.Second
} else {
client.timeout = config.Timeout
client.config.Timeout = config.Timeout
}
client.msgGo = grpool.New(5)
client.rpc = NewRpc(client.Ctx, client.msgGo, client.Logger)
client.isLogin = gtype.NewBool(false)
client.closeFlag = gtype.NewBool(false)
client.stopFlag = gtype.NewBool(false)
client.taskGo = grpool.New(5)
client.msgParser = NewMsgParser(client.handleRoutineTask)
client.registerDefaultRouter()
return
}
// Start 启动tcp连接
func (client *Client) Start() (err error) {
client.Lock()
defer client.Unlock()
client.mutex.Lock()
defer client.mutex.Unlock()
if client.stopFlag {
if client.stopFlag.Val() {
err = gerror.New("client is stop")
return
}
@@ -121,64 +105,66 @@ func (client *Client) Start() (err error) {
return gerror.New("client is running")
}
client.IsLogin = false
client.connectCount = 0
client.closeFlag = false
client.stopFlag = false
client.isLogin.Set(false)
client.config.ConnectCount = 0
client.closeFlag.Set(false)
client.stopFlag.Set(false)
client.wg.Add(1)
simple.SafeGo(client.Ctx, func(ctx context.Context) {
simple.SafeGo(client.ctx, func(ctx context.Context) {
client.connect()
})
return
}
// registerDefaultRouter 注册默认路由
func (client *Client) registerDefaultRouter() {
var routers = []interface{}{
client.onResponseServerLogin, // 服务登录
client.onResponseServerHeartbeat, // 心跳
}
client.RegisterRouter(routers...)
}
// RegisterRouter 注册路由
func (client *Client) RegisterRouter(routers map[string]RouterHandler) (err error) {
if client.conn != nil {
return gerror.New("client is running")
func (client *Client) RegisterRouter(routers ...interface{}) {
err := client.msgParser.RegisterRouter(routers...)
if err != nil {
client.logger.Fatal(client.ctx, err)
}
}
client.Lock()
defer client.Unlock()
if client.routers == nil {
client.routers = make(map[string]RouterHandler)
// 默认路由
client.routers = map[string]RouterHandler{
"ResponseServerHeartbeat": client.onResponseServerHeartbeat,
"ResponseServerLogin": client.onResponseServerLogin,
}
// RegisterRPCRouter 注册RPC路由
func (client *Client) RegisterRPCRouter(routers ...interface{}) {
err := client.msgParser.RegisterRPCRouter(routers...)
if err != nil {
client.logger.Fatal(client.ctx, err)
}
}
for i, router := range routers {
_, ok := client.routers[i]
if ok {
return gerror.Newf("client route duplicate registration:%v", i)
}
client.routers[i] = router
}
return
// RegisterInterceptor 注册拦截器
func (client *Client) RegisterInterceptor(interceptors ...Interceptor) {
client.msgParser.RegisterInterceptor(interceptors...)
}
// dial
func (client *Client) dial() *gtcp.Conn {
for {
conn, err := gtcp.NewConn(client.addr, client.timeout)
if err == nil || client.closeFlag {
conn, err := gtcp.NewConn(client.config.Addr, client.config.Timeout)
if err == nil || client.closeFlag.Val() {
return conn
}
if client.maxConnectCount > 0 {
if client.connectCount < client.maxConnectCount {
client.connectCount += 1
if client.config.MaxConnectCount > 0 {
if client.config.ConnectCount < client.config.MaxConnectCount {
client.config.ConnectCount += 1
} else {
return nil
}
}
client.Logger.Debugf(client.Ctx, "connect to %v error: %v", client.addr, err)
time.Sleep(client.connectInterval)
client.logger.Debugf(client.ctx, "connect to %v error: %v", client.config.Addr, err)
time.Sleep(client.config.ConnectInterval)
continue
}
}
@@ -190,26 +176,24 @@ func (client *Client) connect() {
reconnect:
conn := client.dial()
if conn == nil {
if !client.stopFlag {
client.Logger.Debugf(client.Ctx, "client dial failed")
if !client.stopFlag.Val() {
client.logger.Debugf(client.ctx, "client dial failed")
}
return
}
client.Lock()
if client.closeFlag {
client.Unlock()
_ = conn.Close()
client.Logger.Debugf(client.Ctx, "client connect but closeFlag is true")
client.mutex.Lock()
if client.closeFlag.Val() {
client.mutex.Unlock()
conn.Close()
client.logger.Debugf(client.ctx, "client connect but closeFlag is true")
return
}
client.conn = conn
client.connectCount = 0
client.heartbeat = gtime.Timestamp()
client.conn = NewConn(conn, client.logger, client.msgParser)
client.config.ConnectCount = 0
client.read()
client.Unlock()
client.mutex.Unlock()
client.serverLogin()
client.startCron()
@@ -217,102 +201,38 @@ reconnect:
// read
func (client *Client) read() {
simple.SafeGo(client.Ctx, func(ctx context.Context) {
go func() {
defer func() {
client.Close()
client.Logger.Debugf(client.Ctx, "client are about to be reconnected..")
time.Sleep(client.connectInterval)
_ = client.Start()
client.logger.Debugf(client.ctx, "client are about to be reconnected..")
time.Sleep(client.config.ConnectInterval)
client.Start()
}()
for {
if client.conn == nil {
client.Logger.Debugf(client.Ctx, "client client.conn is nil, server closed")
break
}
msg, err := RecvPkg(client.conn)
if err != nil {
client.Logger.Debugf(client.Ctx, "client RecvPkg err:%+v, server closed", err)
break
}
if client.routers == nil {
client.Logger.Debugf(client.Ctx, "client RecvPkg routers is nil")
break
}
if msg == nil {
client.Logger.Debugf(client.Ctx, "client RecvPkg msg is nil")
break
}
f, ok := client.routers[msg.Router]
if !ok {
client.Logger.Debugf(client.Ctx, "client RecvPkg invalid message: %+v", msg)
continue
}
switch msg.Router {
case "ResponseServerLogin", "ResponseServerHeartbeat": // 服务登录、心跳无需验证签名
client.doHandleRouterMsg(initCtx(gctx.New(), &Context{}), f, msg.Data)
default: // 通用路由消息处理
in, err := VerifySign(msg.Data, client.auth.AppId, client.auth.SecretKey)
if err != nil {
client.Logger.Warningf(client.Ctx, "client read VerifySign err:%+v message: %+v", err, msg)
continue
}
ctx := initCtx(gctx.New(), &Context{
Conn: client.conn,
Auth: client.auth,
TraceID: in.TraceID,
})
// 响应rpc消息
if client.rpc.HandleMsg(ctx, msg.Data) {
return
}
client.doHandleRouterMsg(ctx, f, msg.Data)
}
if err := client.conn.Run(); err != nil {
client.logger.Debug(client.ctx, err)
}
})
}()
}
// Close 关闭同服务器的链接
func (client *Client) Close() {
client.Lock()
defer client.Unlock()
client.mutex.Lock()
defer client.mutex.Unlock()
client.IsLogin = false
client.closeFlag = true
client.isLogin.Set(false)
client.closeFlag.Set(true)
if client.conn != nil {
client.conn.Close()
client.conn = nil
}
if client.closeEvent != nil {
client.closeEvent()
if client.config.CloseEvent != nil {
client.config.CloseEvent()
}
client.wg.Wait()
}
// Stop 停止服务
func (client *Client) Stop() {
if client.stopFlag {
return
}
client.stopFlag = true
client.stopCron()
client.Close()
}
// IsStop 是否已停止
func (client *Client) IsStop() bool {
return client.stopFlag
}
// Destroy 销毁当前连接
func (client *Client) Destroy() {
client.stopCron()
@@ -322,80 +242,62 @@ func (client *Client) Destroy() {
}
}
// Write
func (client *Client) Write(data interface{}) error {
client.Lock()
defer client.Unlock()
if client.conn == nil {
return gerror.New("client conn is nil")
// Stop 停止服务
func (client *Client) Stop() {
if client.stopFlag.Val() {
return
}
client.stopFlag.Set(true)
client.stopCron()
client.Close()
}
if client.closeFlag {
return gerror.New("client conn is closed")
}
// IsStop 是否已停止
func (client *Client) IsStop() bool {
return client.stopFlag.Val()
}
if data == nil {
return gerror.New("client Write message is nil")
}
// IsLogin 是否已登录成功
func (client *Client) IsLogin() bool {
return client.isLogin.Val()
}
msgType := reflect.TypeOf(data)
if msgType == nil || msgType.Kind() != reflect.Ptr {
return gerror.Newf("client json message pointer required: %+v", data)
func (client *Client) handleRoutineTask(ctx context.Context, task func()) {
ctx, cancel := context.WithCancel(ctx)
err := client.taskGo.AddWithRecover(ctx,
func(ctx context.Context) {
task()
cancel()
},
func(ctx context.Context, err error) {
client.logger.Warningf(ctx, "routineTask exec err:%+v", err)
cancel()
},
)
if err != nil {
client.logger.Warningf(ctx, "routineTask add err:%+v", err)
}
msg := &Message{Router: msgType.Elem().Name(), Data: data}
return SendPkg(client.conn, msg)
}
// Conn 获取当前连接
func (client *Client) Conn() *Conn {
return client.conn
}
// Send 发送消息
func (client *Client) Send(ctx context.Context, data interface{}) error {
MsgPkg(data, client.auth, gctx.CtxId(ctx))
return client.Write(data)
}
// Reply 回复消息
func (client *Client) Reply(ctx context.Context, data interface{}) (err error) {
user := GetCtx(ctx)
if user == nil {
err = gerror.New("获取回复用户信息失败")
return
if client.conn == nil {
return gerror.New("conn is nil")
}
MsgPkg(data, client.auth, user.TraceID)
return client.Write(data)
return client.conn.Send(ctx, data)
}
// RpcRequest 发送消息并等待响应结果
func (client *Client) RpcRequest(ctx context.Context, data interface{}) (res interface{}, err error) {
var (
traceID = MsgPkg(data, client.auth, gctx.CtxId(ctx))
key = client.rpc.GetCallId(client.conn, traceID)
)
if traceID == "" {
err = gerror.New("traceID is required")
return
}
return client.rpc.Request(key, func() {
_ = client.Write(data)
})
// Request 发送消息并等待响应结果
func (client *Client) Request(ctx context.Context, data interface{}) (interface{}, error) {
return client.conn.Request(ctx, data)
}
// doHandleRouterMsg 处理路由消息
func (client *Client) doHandleRouterMsg(ctx context.Context, fun RouterHandler, args ...interface{}) {
ctx, cancel := context.WithCancel(ctx)
err := client.msgGo.AddWithRecover(ctx,
func(ctx context.Context) {
fun(ctx, args...)
cancel()
},
func(ctx context.Context, err error) {
client.Logger.Warningf(ctx, "doHandleRouterMsg msgGo exec err:%+v", err)
cancel()
},
)
if err != nil {
client.Logger.Warningf(ctx, "doHandleRouterMsg msgGo Add err:%+v", err)
return
}
// RequestScan 发送消息并等待响应结果将结果保存在response中
func (client *Client) RequestScan(ctx context.Context, data, response interface{}) error {
return client.conn.RequestScan(ctx, data, response)
}

View File

@@ -10,12 +10,11 @@ import (
"fmt"
"github.com/gogf/gf/v2/os/gcron"
"github.com/gogf/gf/v2/os/gtime"
"hotgo/internal/consts"
)
// getCronKey 生成客户端定时任务名称
func (client *Client) getCronKey(s string) string {
return fmt.Sprintf("tcp.client_%s_%s:%s", s, client.auth.Group, client.auth.Name)
return fmt.Sprintf("tcp.client_%s:%s", s, client.conn.LocalAddr().String())
}
// stopCron 停止定时任务
@@ -28,19 +27,22 @@ func (client *Client) stopCron() {
// startCron 启动定时任务
func (client *Client) startCron() {
// 心跳超时检查
if gcron.Search(client.getCronKey(consts.TCPCronHeartbeatVerify)) == nil {
_, _ = gcron.AddSingleton(client.Ctx, "@every 600s", func(ctx context.Context) {
if client.heartbeat < gtime.Timestamp()-consts.TCPHeartbeatTimeout {
client.Logger.Debugf(client.Ctx, "client heartbeat timeout, about to reconnect..")
if gcron.Search(client.getCronKey(CronHeartbeatVerify)) == nil {
_, _ = gcron.AddSingleton(client.ctx, "@every 600s", func(ctx context.Context) {
if client == nil || client.conn == nil {
return
}
if client.conn.Heartbeat < gtime.Timestamp()-HeartbeatTimeout {
client.logger.Debugf(client.ctx, "client heartbeat timeout, about to reconnect..")
client.Destroy()
}
}, client.getCronKey(consts.TCPCronHeartbeatVerify))
}, client.getCronKey(CronHeartbeatVerify))
}
// 心跳
if gcron.Search(client.getCronKey(consts.TCPCronHeartbeat)) == nil {
_, _ = gcron.AddSingleton(client.Ctx, "@every 120s", func(ctx context.Context) {
if gcron.Search(client.getCronKey(CronHeartbeat)) == nil {
_, _ = gcron.AddSingleton(client.ctx, "@every 300s", func(ctx context.Context) {
client.serverHeartbeat()
}, client.getCronKey(consts.TCPCronHeartbeat))
}, client.getCronKey(CronHeartbeat))
}
}

View File

@@ -7,70 +7,73 @@ package tcp
import (
"context"
"fmt"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/consts"
"hotgo/internal/model/input/msgin"
"hotgo/utility/encrypt"
)
// serverLogin 心跳
func (client *Client) serverHeartbeat() {
if !client.isLogin.Val() {
return
}
ctx := gctx.New()
if err := client.Send(ctx, &msgin.ServerHeartbeat{}); err != nil {
client.Logger.Debugf(ctx, "client WriteMsg ServerHeartbeat err:%+v", err)
if err := client.conn.Send(ctx, &ServerHeartbeatReq{}); err != nil {
client.logger.Warningf(ctx, "client ServerHeartbeat Send err:%+v", err)
return
}
}
// serverLogin 服务登陆
func (client *Client) serverLogin() {
data := &msgin.ServerLogin{
Group: client.auth.Group,
Name: client.auth.Name,
auth := client.config.Auth
// 无需登录
if auth == nil {
return
}
data := &ServerLoginReq{
Name: auth.Name,
Extra: auth.Extra,
Group: auth.Group,
AppId: auth.AppId,
Timestamp: gtime.Timestamp(),
}
// 签名
data.Sign = encrypt.Md5ToString(fmt.Sprintf("%v%v%v", data.AppId, data.Timestamp, auth.SecretKey))
ctx := gctx.New()
if err := client.Send(ctx, data); err != nil {
client.Logger.Debugf(ctx, "client WriteMsg ServerLogin err:%+v", err)
if err := client.conn.Send(ctx, data); err != nil {
client.logger.Warningf(ctx, "client ServerLogin Send err:%+v", err)
return
}
}
// onResponseServerLogin 接收服务登陆响应结果
func (client *Client) onResponseServerLogin(ctx context.Context, args ...interface{}) {
var in *msgin.ResponseServerLogin
if err := gconv.Scan(args[0], &in); err != nil {
client.Logger.Infof(ctx, "onResponseServerLogin message Scan failed:%+v, args:%+v", err, args[0])
return
}
if in.Code != consts.TCPMsgCodeSuccess {
client.IsLogin = false
client.Logger.Warningf(ctx, "onResponseServerLogin quit err:%v", in.Message)
func (client *Client) onResponseServerLogin(ctx context.Context, req *ServerLoginRes) {
if err := req.GetError(); err != nil {
client.isLogin.Set(false)
client.logger.Warningf(ctx, "onResponseServerLogin destroy, err:%v", err)
client.Destroy()
return
}
client.IsLogin = true
client.isLogin.Set(true)
if client.loginEvent != nil {
client.loginEvent()
if client.config.LoginEvent != nil {
client.config.LoginEvent()
}
}
// onResponseServerHeartbeat 接收心跳响应结果
func (client *Client) onResponseServerHeartbeat(ctx context.Context, args ...interface{}) {
var in *msgin.ResponseServerHeartbeat
if err := gconv.Scan(args[0], &in); err != nil {
client.Logger.Infof(ctx, "onResponseServerHeartbeat message Scan failed:%+v, args:%+v", err, args)
func (client *Client) onResponseServerHeartbeat(ctx context.Context, req *ServerHeartbeatRes) {
if err := req.GetError(); err != nil {
client.logger.Warningf(ctx, "onResponseServerHeartbeat err:%v", err)
return
}
if in.Code != consts.TCPMsgCodeSuccess {
client.Logger.Warningf(ctx, "onResponseServerHeartbeat err:%v", in.Message)
return
}
client.heartbeat = gtime.Timestamp()
client.conn.Heartbeat = gtime.Timestamp()
}

View File

@@ -0,0 +1,177 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"context"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/grand"
"net"
"sync/atomic"
)
// AuthMeta 认证元数据
type AuthMeta struct {
Name string `json:"name"` // 客户端名称当同一个应用ID有多个客户端时请使用不同的名称区分。比如cron1,cron2
Extra g.Map `json:"extra"` // 自定义数据,可以传递一些额外的自定义数据
Group string `json:"group"` // 客户端分组
AppId string `json:"appId"` // 应用ID
SecretKey string `json:"secretKey"` // 应用秘钥
EndAt *gtime.Time `json:"endAt"` // 授权过期时间
Routes []string `json:"routes"` // 授权路由
}
// Conn tcp连接
type Conn struct {
CID int64 // 连接ID
Conn *gtcp.Conn // 连接对象
Auth *AuthMeta // 认证元数据
Heartbeat int64 // 心跳
FirstTime int64 // 首次连接时间
writeChan chan []byte // 发数据
closeFlag *gtype.Bool // 关闭标签
logger *glog.Logger // 日志处理器
msgParser *MsgParser // 消息处理器
}
var idCounter int64
func NewConn(conn *gtcp.Conn, logger *glog.Logger, msgParser *MsgParser) *Conn {
tcpConn := new(Conn)
tcpConn.CID = atomic.AddInt64(&idCounter, 1)
tcpConn.Conn = conn
tcpConn.Heartbeat = gtime.Timestamp()
tcpConn.FirstTime = gtime.Timestamp()
tcpConn.writeChan = make(chan []byte, 1000)
tcpConn.closeFlag = gtype.NewBool(false)
tcpConn.logger = logger
tcpConn.msgParser = msgParser
go func() {
for b := range tcpConn.writeChan {
if b == nil {
break
}
if err := conn.SendPkg(b); err != nil {
break
}
}
}()
return tcpConn
}
func (c *Conn) Run() error {
for {
data, err := c.Conn.RecvPkg()
if err != nil {
return gerror.NewCodef(gcode.CodeInvalidRequest, "read packet err:%+v conn closed", err)
}
if c.closeFlag.Val() {
return nil
}
msg, err := c.msgParser.Encoding(data)
if err != nil {
return gerror.NewCodef(gcode.CodeInternalError, "message encoding err:%+v conn closed", err)
}
ctx, err := c.bindContext(msg)
if err != nil {
return gerror.NewCodef(gcode.CodeInternalError, "bindContext err:%+v message: %+v", err, msg)
}
if err = c.msgParser.handleInterceptor(ctx, msg); err != nil {
c.logger.Warning(ctx, gerror.Wrap(err, "interceptor authentication failed"))
continue
}
if err = c.msgParser.handleRouterMsg(ctx, msg); err != nil {
return err
}
}
}
// RemoteAddr returns the remote network address, if known.
func (c *Conn) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
}
// LocalAddr returns the local network address, if known.
func (c *Conn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
// Write
func (c *Conn) Write(b []byte) {
if !c.closeFlag.Val() {
c.writeChan <- b
}
}
// Send 发送消息
func (c *Conn) Send(ctx context.Context, data interface{}) error {
if c.closeFlag.Val() {
return gerror.New("conn is closed")
}
b, err := c.msgParser.Decoding(ctx, data, "")
if err != nil {
return err
}
c.Write(b)
return nil
}
func (c *Conn) Close() {
if c.closeFlag.Val() {
return
}
c.closeFlag.Set(true)
close(c.writeChan)
c.Conn.Close()
}
// Request 发送消息并等待响应结果
func (c *Conn) Request(ctx context.Context, data interface{}) (interface{}, error) {
if c.closeFlag.Val() {
return nil, gerror.New("conn is closed")
}
msgId := grand.S(16)
b, err := c.msgParser.Decoding(ctx, data, msgId)
if err != nil {
return nil, err
}
return c.msgParser.rpc.Request(ctx, msgId, func() {
c.Write(b)
})
}
// RequestScan 发送消息并等待响应结果将结果保存在response中
func (c *Conn) RequestScan(ctx context.Context, data, response interface{}) error {
body, err := c.Request(ctx, data)
if err != nil {
return err
}
return gvar.New(body).Scan(response)
}
// bindContext 将用户身份绑定到上下文
func (c *Conn) bindContext(msg *Message) (ctx context.Context, err error) {
ctx = initCtx(gctx.New(), &Context{
Conn: c,
})
return SetCtxTraceID(ctx, msg.TraceId)
}

View File

@@ -0,0 +1,32 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
// 定时任务
const (
CronHeartbeatVerify = "tcpHeartbeatVerify"
CronHeartbeat = "tcpHeartbeat"
CronAuthVerify = "tcpAuthVerify"
)
const (
HeartbeatTimeout = 300 // tcp心跳超时默认300s
RPCTimeout = 10 // rpc通讯超时时间 默认10s
)
const (
ParseRouterErrInvalidParams = "register router[%v] method must have two params"
ParseRouterRPCErrInvalidParams = "register RPC router [%v] method must have two response params"
ParseRouterErrInvalidFirstParam = "the first params of the processing method that registers the router[%v] must be of type context.Context"
ParseRouterErrInvalidSecondParam = "the second params of the processing method that registers the router[%v] must be of type pointer to a struct"
)
type CtxKey string
// ContextKey 上下文
const (
ContextKey CtxKey = "tcpContext" // tcp上下文变量名称
)

View File

@@ -8,23 +8,21 @@ package tcp
import (
"context"
"github.com/gogf/gf/v2/net/gtrace"
"hotgo/internal/consts"
)
// Context tcp上下文
type Context struct {
Conn *Conn
}
// initCtx 初始化上下文对象指针到上下文对象中,以便后续的请求流程中可以修改
func initCtx(ctx context.Context, model *Context) (newCtx context.Context) {
if model.TraceID != "" {
newCtx, _ = gtrace.WithTraceID(ctx, model.TraceID)
} else {
newCtx = ctx
}
newCtx = context.WithValue(newCtx, consts.ContextTCPKey, model)
return
func initCtx(ctx context.Context, model *Context) context.Context {
return context.WithValue(ctx, ContextKey, model)
}
// GetCtx 获得上下文变量如果没有设置那么返回nil
func GetCtx(ctx context.Context) *Context {
value := ctx.Value(consts.ContextTCPKey)
value := ctx.Value(ContextKey)
if value == nil {
return nil
}
@@ -33,3 +31,20 @@ func GetCtx(ctx context.Context) *Context {
}
return nil
}
// ConnFromCtx retrieves and returns the Conn object from context.
func ConnFromCtx(ctx context.Context) *Conn {
user := GetCtx(ctx)
if user == nil {
return nil
}
return user.Conn
}
// SetCtxTraceID 将自定义跟踪ID注入上下文以进行传播
func SetCtxTraceID(ctx context.Context, traceID string) (context.Context, error) {
if len(traceID) > 0 {
return gtrace.WithTraceID(ctx, traceID)
}
return ctx, nil
}

View File

@@ -1,30 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gtime"
)
// AuthMeta 认证元数据
type AuthMeta struct {
Group string `json:"group"`
Name string `json:"name"`
AppId string `json:"appId"`
SecretKey string `json:"secretKey"`
EndAt *gtime.Time `json:"-"`
}
// Context tcp上下文
type Context struct {
Conn *gtcp.Conn `json:"conn"`
Auth *AuthMeta `json:"auth"` // 认证元数据
TraceID string `json:"traceID"` // 链路ID
}
// CallbackEvent 回调事件
type CallbackEvent func()

View File

@@ -0,0 +1,84 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
)
type ServerRes struct {
Code int `json:"code" example:"2000" description:"状态码"`
Message string `json:"message,omitempty" example:"操作成功" description:"提示消息"`
}
// SetCode 设置状态码
func (i *ServerRes) SetCode(code ...int) {
if len(code) > 0 {
i.Code = code[0]
return
}
// 默认值,转为成功的状态码
if i.Code == 0 {
i.Code = gcode.CodeOK.Code()
}
}
// SetMessage 设置提示消息
func (i *ServerRes) SetMessage(msg ...string) {
message := "操作成功"
if len(msg) > 0 {
message = msg[0]
return
}
i.Message = message
}
// SetError 设置响应中的错误
func (i *ServerRes) SetError(err error) {
if err != nil {
i.Code = gerror.Code(err).Code()
i.Message = err.Error()
}
return
}
// GetError 获取响应中的错误
func (i *ServerRes) GetError() (err error) {
if i.Code != gcode.CodeOK.Code() {
if i.Message == "" {
i.Message = "操作失败"
}
err = gerror.NewCode(gcode.New(i.Code, i.Message, ""))
}
return
}
// ServerLoginReq 服务登录
type ServerLoginReq struct {
Name string `json:"name" description:"客户端名称"` // 客户端名称当同一个应用ID有多个客户端时请使用不同的名称区分。比如cron1,cron2
Extra g.Map `json:"extra" description:"自定义数据"` // 自定义数据,可以传递一些额外的自定义数据
Group string `json:"group" description:"分组"`
AppId string `json:"appID" description:"应用ID"`
Timestamp int64 `json:"timestamp" description:"服务器时间戳"`
Sign string `json:"sign" description:"签名"`
}
// ServerLoginRes 响应服务登录
type ServerLoginRes struct {
ServerRes
}
// ServerHeartbeatReq 心跳
type ServerHeartbeatReq struct {
}
// ServerHeartbeatRes 响应心跳
type ServerHeartbeatRes struct {
ServerRes
}

View File

@@ -0,0 +1,222 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"context"
"encoding/json"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
"reflect"
"sync"
)
// MsgParser 消息处理器
type MsgParser struct {
mutex sync.Mutex // 路由锁
task RoutineTask // 投递协程任务
rpc *RPC // rpc
routers map[string]*RouteHandler // 已注册的路由
interceptors []Interceptor // 拦截器
}
// RoutineTask 投递协程任务
type RoutineTask func(ctx context.Context, task func())
// Interceptor 拦截器
type Interceptor func(ctx context.Context, msg *Message) (err error)
// Message 标准消息
type Message struct {
Router string `json:"router"` // 路由
TraceId string `json:"traceId"` // 链路ID
Data interface{} `json:"data"` // 数据
MsgId string `json:"msgId,omitempty"` // 消息IDrpc用
Error string `json:"error,omitempty"` // 消息错误rpc用
}
// NewMsgParser 初始化消息处理器
func NewMsgParser(task RoutineTask) *MsgParser {
m := new(MsgParser)
m.task = task
m.routers = make(map[string]*RouteHandler)
m.rpc = NewRPC(task)
return m
}
// RegisterRouter 注册路由
func (m *MsgParser) RegisterRouter(routers ...interface{}) (err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, router := range routers {
info, err := ParseRouteHandler(router, false)
if err != nil {
return err
}
if _, ok := m.routers[info.Id]; ok {
return gerror.Newf("server router duplicate registration:%v", info.Id)
}
m.routers[info.Id] = info
}
return
}
// RegisterRPCRouter 注册rpc路由
func (m *MsgParser) RegisterRPCRouter(routers ...interface{}) (err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, router := range routers {
info, err := ParseRouteHandler(router, true)
if err != nil {
return err
}
if _, ok := m.routers[info.Id]; ok {
return gerror.Newf("server rpc router duplicate registration:%v", info.Id)
}
m.routers[info.Id] = info
}
return
}
// RegisterInterceptor 注册拦截器
func (m *MsgParser) RegisterInterceptor(interceptors ...Interceptor) {
m.interceptors = append(interceptors, interceptors...)
return
}
// Encoding 消息编码
func (m *MsgParser) Encoding(data []byte) (*Message, error) {
var msg Message
if err := gconv.Scan(data, &msg); err != nil {
return nil, gerror.Newf("invalid package struct: %s", err.Error())
}
if msg.Router == "" {
return nil, gerror.Newf("message is not router: %+v", msg)
}
return &msg, nil
}
// Decoding 消息解码
func (m *MsgParser) Decoding(ctx context.Context, data interface{}, msgId string) ([]byte, error) {
message, err := m.doDecoding(ctx, data, msgId)
if err != nil {
return nil, err
}
return json.Marshal(message)
}
// Decoding 消息解码
func (m *MsgParser) doDecoding(ctx context.Context, data interface{}, msgId string) (*Message, error) {
msgType := reflect.TypeOf(data)
if msgType == nil || msgType.Kind() != reflect.Ptr {
return nil, gerror.Newf("json message pointer required: %+v", data)
}
message := &Message{Router: msgType.Elem().Name(), TraceId: gctx.CtxId(ctx), MsgId: msgId, Data: data}
return message, nil
}
// handleRouterMsg 处理路由消息
func (m *MsgParser) handleRouterMsg(ctx context.Context, msg *Message) error {
// rpc消息
if m.rpc.Response(ctx, msg) {
return nil
}
handler, ok := m.routers[msg.Router]
if !ok {
return gerror.NewCodef(gcode.CodeInternalError, "invalid message: %+v, or the router is not registered.", msg)
}
return m.doHandleRouterMsg(ctx, handler, msg)
}
// doHandleRouterMsg 处理路由消息
func (m *MsgParser) doHandleRouterMsg(ctx context.Context, handler *RouteHandler, msg *Message) (err error) {
var input = gutil.Copy(handler.Input.Interface())
if err = gjson.New(msg.Data).Scan(input); err != nil {
return gerror.NewCodef(
gcode.CodeInvalidParameter,
"router scan failed:%v to parse message:%v",
err, handler.Id)
}
args := []reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(input)}
m.task(ctx, func() {
results := handler.Func.Call(args)
if handler.IsRPC {
switch len(results) {
case 2:
var responseErr error
if !results[1].IsNil() {
if err, ok := results[1].Interface().(error); ok {
responseErr = err
}
}
responseMsg, deErr := m.doDecoding(ctx, results[0].Interface(), msg.MsgId)
if deErr != nil && responseErr == nil {
responseErr = deErr
return
}
if responseErr != nil {
responseMsg.Error = responseErr.Error()
}
b, err := json.Marshal(responseMsg)
if err != nil {
return
}
user := GetCtx(ctx)
user.Conn.Write(b)
}
return
}
})
return
}
// handleInterceptor 处理拦截器
func (m *MsgParser) handleInterceptor(ctx context.Context, msg *Message) (interceptErr error) {
for _, f := range m.interceptors {
if interceptErr = f(ctx, msg); interceptErr != nil {
break
}
}
if interceptErr == nil {
return
}
handler, ok := m.routers[msg.Router]
if !ok {
return
}
if handler.IsRPC {
var output = gutil.Copy(handler.Output.Interface())
response, doerr := m.doDecoding(ctx, output, msg.MsgId)
if doerr != nil {
return doerr
}
response.Error = interceptErr.Error()
b, err := json.Marshal(response)
if err != nil {
return err
}
ConnFromCtx(ctx).Write(b)
}
return
}

View File

@@ -1,27 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
type Response interface {
PkgResponse()
GetError() (err error)
}
// PkgResponse 打包响应消息
func PkgResponse(data interface{}) {
if c, ok := data.(Response); ok {
c.PkgResponse()
return
}
}
// GetResponseError 解析响应消息中的错误
func GetResponseError(data interface{}) (err error) {
if c, ok := data.(Response); ok {
return c.GetError()
}
return
}

View File

@@ -7,56 +7,88 @@ package tcp
import (
"context"
"encoding/json"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/text/gstr"
"reflect"
"runtime"
)
// RouterHandler 路由消息处理器
type RouterHandler func(ctx context.Context, args ...interface{})
// Message 路由消息
type Message struct {
Router string `json:"router"`
Data interface{} `json:"data"`
// RouteHandler 路由处理器
type RouteHandler struct {
Id string // 路由ID
IsRPC bool // 是否支持rpc协议
Func reflect.Value // 路由处理方法
Input reflect.Value // 输入参数
Output reflect.Value // 输出参数
}
// SendPkg 打包发送的数据包
func SendPkg(conn *gtcp.Conn, message *Message) error {
b, err := json.Marshal(message)
if err != nil {
return err
// ParseRouteHandler 解析路由
func ParseRouteHandler(router interface{}, isRPC bool) (info *RouteHandler, err error) {
funcName := runtime.FuncForPC(reflect.ValueOf(router).Pointer()).Name()
funcType := reflect.ValueOf(router).Type()
if funcType.NumIn() != 2 {
err = gerror.Newf(ParseRouterErrInvalidParams, funcName)
return
}
return conn.SendPkg(b)
}
// RecvPkg 解包
func RecvPkg(conn *gtcp.Conn) (*Message, error) {
if data, err := conn.RecvPkg(); err != nil {
return nil, err
} else {
var msg = new(Message)
if err = gconv.Scan(data, &msg); err != nil {
return nil, gerror.Newf("invalid package structure: %s", err.Error())
}
if msg.Router == "" {
return nil, gerror.Newf("message is not routed: %+v", msg)
}
return msg, err
if funcType.In(0) != reflect.TypeOf((*context.Context)(nil)).Elem() {
err = gerror.Newf(ParseRouterErrInvalidFirstParam, funcName)
return
}
}
// MsgPkg 打包消息
func MsgPkg(data interface{}, auth *AuthMeta, traceID string) string {
// 打包签名
msg := PkgSign(data, auth.AppId, auth.SecretKey, traceID)
// 打包响应消息
PkgResponse(data)
if msg == nil {
return ""
inputType := funcType.In(1)
if !(inputType.Kind() == reflect.Ptr && inputType.Elem().Kind() == reflect.Struct) {
err = gerror.Newf(ParseRouterErrInvalidSecondParam, funcName)
return
}
return msg.TraceID
// The request struct should be named as `xxxReq`.
if !gstr.HasSuffix(inputType.String(), `Req`) && !gstr.HasSuffix(inputType.String(), `Res`) {
err = gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid struct naming of the request: defined as "%s", but should be named with the "Req" or "Res" suffix, such as "XxxReq" or "XxxRes"`,
inputType.String(),
)
return
}
info = &RouteHandler{
Id: gstr.SubStrFromREx(inputType.String(), `.`),
IsRPC: isRPC,
Func: reflect.ValueOf(router),
Input: reflect.New(inputType.Elem()),
}
if !isRPC {
return
}
if funcType.NumOut() != 2 {
err = gerror.Newf(ParseRouterRPCErrInvalidParams, funcName)
return
}
outputType := funcType.Out(0)
// The response struct should be named as `xxxRes`.
if !gstr.HasSuffix(outputType.String(), `Res`) {
err = gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid struct naming for response: defined as "%s", but it should be named with "Res" suffix like "XxxRes"`,
outputType.String(),
)
return
}
if !funcType.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
err = gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid handler: defined as "%s", but the last output parameter should be type of "error"`,
reflect.TypeOf(funcType).String(),
)
return
}
info.Output = reflect.New(outputType.Elem())
return
}

View File

@@ -7,118 +7,95 @@ package tcp
import (
"context"
"fmt"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
"hotgo/internal/consts"
"hotgo/utility/simple"
"sync"
"time"
)
type Rpc struct {
ctx context.Context
// RPC .
type RPC struct {
mutex sync.Mutex
callbacks map[string]RpcRespFunc
msgGo *grpool.Pool // 消息处理协程池
logger *glog.Logger // 日志处理器
callbacks map[string]RPCResponseFunc
task RoutineTask
}
// RpcResp 响应结构
type RpcResp struct {
// RPCResponse 响应结构
type RPCResponse struct {
res interface{}
err error
}
type RpcRespFunc func(resp interface{}, err error)
type RPCResponseFunc func(resp interface{}, err error)
// NewRpc 初始化一个rpc协议
func NewRpc(ctx context.Context, msgGo *grpool.Pool, logger *glog.Logger) *Rpc {
return &Rpc{
ctx: ctx,
callbacks: make(map[string]RpcRespFunc),
msgGo: msgGo,
logger: logger,
// NewRPC 初始化RPC
func NewRPC(task RoutineTask) *RPC {
return &RPC{
task: task,
callbacks: make(map[string]RPCResponseFunc),
}
}
// GetCallId 获取回调id
func (r *Rpc) GetCallId(client *gtcp.Conn, traceID string) string {
return fmt.Sprintf("%v.%v", client.LocalAddr().String(), traceID)
}
// HandleMsg 处理rpc消息
func (r *Rpc) HandleMsg(ctx context.Context, data interface{}) bool {
user := GetCtx(ctx)
callId := r.GetCallId(user.Conn, user.TraceID)
if call, ok := r.callbacks[callId]; ok {
r.mutex.Lock()
delete(r.callbacks, callId)
r.mutex.Unlock()
ctx, cancel := context.WithCancel(ctx)
err := r.msgGo.AddWithRecover(ctx, func(ctx context.Context) {
call(data, nil)
cancel()
}, func(ctx context.Context, err error) {
r.logger.Warningf(ctx, "rpc HandleMsg msgGo exec err:%+v", err)
cancel()
})
if err != nil {
r.logger.Warningf(ctx, "rpc HandleMsg msgGo Add err:%+v", err)
}
return true
}
return false
}
// Request 发起rpc请求
func (r *Rpc) Request(callId string, send func()) (res interface{}, err error) {
var (
waitCh = make(chan struct{})
resCh = make(chan RpcResp, 1)
isClose = false
)
// Request 发起RPC请求
func (r *RPC) Request(ctx context.Context, msgId string, send func()) (res interface{}, err error) {
resCh := make(chan RPCResponse, 1)
isClose := gtype.NewBool(false)
defer func() {
isClose = true
isClose.Set(true)
close(resCh)
// 移除消息
if _, ok := r.callbacks[callId]; ok {
r.mutex.Lock()
delete(r.callbacks, callId)
r.mutex.Unlock()
}
r.popCallback(msgId)
}()
simple.SafeGo(r.ctx, func(ctx context.Context) {
close(waitCh)
// 加入回调
r.mutex.Lock()
r.callbacks[callId] = func(res interface{}, err error) {
if !isClose {
resCh <- RpcResp{res: res, err: err}
}
r.mutex.Lock()
r.callbacks[msgId] = func(res interface{}, err error) {
if !isClose.Val() {
resCh <- RPCResponse{res: res, err: err}
}
r.mutex.Unlock()
}
r.mutex.Unlock()
// 发送消息
send()
})
r.task(ctx, send)
<-waitCh
select {
case <-time.After(time.Second * consts.TCPRpcTimeout):
err = gerror.New("rpc response timeout")
case <-time.After(time.Second * RPCTimeout):
err = gerror.New("RPC response timeout")
return
case got := <-resCh:
return got.res, got.err
}
}
// Response RPC消息响应
func (r *RPC) Response(ctx context.Context, msg *Message) bool {
if len(msg.MsgId) == 0 {
return false
}
f, ok := r.popCallback(msg.MsgId)
if !ok {
return false
}
var msgError error
if len(msg.Error) > 0 {
msgError = gerror.New(msg.Error)
}
r.task(ctx, func() {
f(msg.Data, msgError)
})
return true
}
// popCallback 弹出回调
func (r *RPC) popCallback(msgId string) (RPCResponseFunc, bool) {
r.mutex.Lock()
defer r.mutex.Unlock()
call, ok := r.callbacks[msgId]
if ok {
delete(r.callbacks, msgId)
}
return call, ok
}

View File

@@ -7,24 +7,30 @@ package tcp
import (
"context"
"github.com/gogf/gf/v2/container/gtype"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gtcp"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
"hotgo/internal/consts"
"hotgo/utility/simple"
"reflect"
"sync"
"time"
)
// ClientConn 连接到tcp服务器的客户端对象
type ClientConn struct {
Conn *gtcp.Conn // 连接对象
Auth *AuthMeta // 认证元数据
heartbeat int64 // 心跳
// Server tcp服务器
type Server struct {
ctx context.Context // 上下文
name string // 服务器名称
addr string // 服务器地址
ln *gtcp.Server // tcp服务器
logger *glog.Logger // 日志处理器
wgLn sync.WaitGroup // 流程控制让tcp服务器按可控流程启动退出
closeFlag *gtype.Bool // 服务关闭标签
clients map[string]*Conn // 已登录的认证客户端
mutexConns sync.Mutex // 连接锁,主要用于客户端上下线
taskGo *grpool.Pool // 任务协程池
msgParser *MsgParser // 消息处理器
}
// ServerConfig tcp服务器配置
@@ -33,182 +39,121 @@ type ServerConfig struct {
Addr string // 监听地址
}
// Server tcp服务器对象结构
type Server struct {
Ctx context.Context // 上下文
Logger *glog.Logger // 日志处理器
addr string // 连接地址
name string // 服务器名称
rpc *Rpc // rpc协议
ln *gtcp.Server // tcp服务器
wgLn sync.WaitGroup // 状态控制主要用于tcp服务器能够按流程启动退出
mutex sync.Mutex // 服务器状态锁
closeFlag bool // 服务关闭标签
clients map[string]*ClientConn // 已登录的认证客户端
mutexConns sync.Mutex // 连接锁,主要用于客户端上下线
msgGo *grpool.Pool // 消息处理协程池
cronRouters map[string]RouterHandler // 定时任务路由
queueRouters map[string]RouterHandler // 队列路由
authRouters map[string]RouterHandler // 任务路由
}
// NewServer 初始一个tcp服务器对象
func NewServer(config *ServerConfig) (server *Server, err error) {
func NewServer(config *ServerConfig) (server *Server) {
server = new(Server)
server.ctx = gctx.New()
server.logger = g.Log("tcpServer")
baseErr := gerror.New("TCPServer start fail")
if config == nil {
err = gerror.New("config is nil")
server.logger.Fatal(server.ctx, gerror.Wrap(baseErr, "config is nil"))
return
}
if config.Addr == "" {
err = gerror.New("server address is not set")
server.logger.Fatal(server.ctx, gerror.Wrap(baseErr, "server address is not set"))
return
}
server = new(Server)
server.Ctx = gctx.New()
if config.Name == "" {
config.Name = simple.AppName(server.Ctx)
config.Name = simple.AppName(server.ctx)
}
server.addr = config.Addr
server.name = config.Name
server.ln = gtcp.NewServer(server.addr, server.accept, config.Name)
server.clients = make(map[string]*ClientConn)
server.closeFlag = false
server.Logger = g.Log("tcpServer")
server.msgGo = grpool.New(20)
server.rpc = NewRpc(server.Ctx, server.msgGo, server.Logger)
server.closeFlag = gtype.NewBool(false)
server.clients = make(map[string]*Conn)
server.taskGo = grpool.New(20)
server.msgParser = NewMsgParser(server.handleRoutineTask)
server.startCron()
return
}
// accept
func (server *Server) accept(conn *gtcp.Conn) {
defer func() {
server.mutexConns.Lock()
_ = conn.Close()
// 从登录列表中移除
delete(server.clients, conn.RemoteAddr().String())
server.mutexConns.Unlock()
tcpConn := NewConn(conn, server.logger, server.msgParser)
server.AddClient(tcpConn)
go func() {
if err := tcpConn.Run(); err != nil {
server.logger.Info(server.ctx, err)
}
// cleanup
tcpConn.Close()
server.RemoveClient(tcpConn)
}()
}
for {
msg, err := RecvPkg(conn)
if err != nil {
server.Logger.Debugf(server.Ctx, "RecvPkg err:%+v, client closed.", err)
break
}
client := server.getLoginConn(conn)
switch msg.Router {
case "ServerLogin": // 服务登录
// 初始化上下文
ctx := initCtx(gctx.New(), &Context{
Conn: conn,
})
server.doHandleRouterMsg(ctx, server.onServerLogin, msg.Data)
case "ServerHeartbeat": // 心跳
if client == nil {
server.Logger.Infof(server.Ctx, "conn not connected, ignore the heartbeat, msg:%+v", msg)
continue
}
// 初始化上下文
ctx := initCtx(gctx.New(), &Context{})
server.doHandleRouterMsg(ctx, server.onServerHeartbeat, msg.Data, client)
default: // 通用路由消息处理
if client == nil {
server.Logger.Warningf(server.Ctx, "conn is not logged in but sends a routing message. actively conn disconnect, msg:%+v", msg)
time.Sleep(time.Second)
conn.Close()
return
}
server.handleRouterMsg(msg, client)
}
// RemoveClient 移除客户端
func (server *Server) RemoveClient(conn *Conn) {
label := server.ClientLabel(conn.Conn)
if _, ok := server.clients[label]; ok {
server.mutexConns.Lock()
delete(server.clients, label)
server.mutexConns.Unlock()
}
}
// handleRouterMsg 处理路由消息
func (server *Server) handleRouterMsg(msg *Message, client *ClientConn) {
// 验证签名
in, err := VerifySign(msg.Data, client.Auth.AppId, client.Auth.SecretKey)
if err != nil {
server.Logger.Warningf(server.Ctx, "handleRouterMsg VerifySign err:%+v message: %+v", err, msg)
return
}
// 初始化上下文
ctx := initCtx(gctx.New(), &Context{
Conn: client.Conn,
Auth: client.Auth,
TraceID: in.TraceID,
})
// 响应rpc消息
if server.rpc.HandleMsg(ctx, msg.Data) {
return
}
handle := func(routers map[string]RouterHandler, group string) {
if routers == nil {
server.Logger.Debugf(server.Ctx, "handleRouterMsg route is not initialized %v message: %+v", group, msg)
return
}
f, ok := routers[msg.Router]
if !ok {
server.Logger.Debugf(server.Ctx, "handleRouterMsg invalid %v message: %+v", group, msg)
return
}
server.doHandleRouterMsg(ctx, f, msg.Data)
}
switch client.Auth.Group {
case consts.TCPClientGroupCron:
handle(server.cronRouters, client.Auth.Group)
case consts.TCPClientGroupQueue:
handle(server.queueRouters, client.Auth.Group)
case consts.TCPClientGroupAuth:
handle(server.authRouters, client.Auth.Group)
default:
server.Logger.Warningf(server.Ctx, "group is not registered: %+v", client.Auth.Group)
}
// AddClient 添加客户端
func (server *Server) AddClient(conn *Conn) {
server.mutexConns.Lock()
server.clients[server.ClientLabel(conn.Conn)] = conn
server.mutexConns.Unlock()
}
// doHandleRouterMsg 处理路由消息
func (server *Server) doHandleRouterMsg(ctx context.Context, fun RouterHandler, args ...interface{}) {
ctx, cancel := context.WithCancel(ctx)
err := server.msgGo.AddWithRecover(ctx,
func(ctx context.Context) {
fun(ctx, args...)
cancel()
},
func(ctx context.Context, err error) {
server.Logger.Warningf(ctx, "doHandleRouterMsg msgGo exec err:%+v", err)
cancel()
},
)
if err != nil {
server.Logger.Warningf(ctx, "doHandleRouterMsg msgGo Add err:%+v", err)
// AuthClient 认证客户端
func (server *Server) AuthClient(conn *Conn, auth *AuthMeta) {
label := server.ClientLabel(conn.Conn)
client, ok := server.clients[label]
if !ok {
server.logger.Debugf(server.ctx, "authClient client does not exist:%v", label)
return
}
client.Auth = auth
}
// getLoginConn 获取指定已登录的连接
func (server *Server) getLoginConn(conn *gtcp.Conn) *ClientConn {
client, ok := server.clients[conn.RemoteAddr().String()]
// ClientLabel 客户端标识
func (server *Server) ClientLabel(conn *gtcp.Conn) string {
return conn.RemoteAddr().String()
}
// GetClient 获取指定连接
func (server *Server) GetClient(conn *gtcp.Conn) *Conn {
client, ok := server.clients[server.ClientLabel(conn)]
if !ok {
return nil
}
return client
}
// getLoginConn 获取指定appid的所有连接
func (server *Server) getAppIdClients(appid string) (list []*ClientConn) {
// GetClients 获取所有连接
func (server *Server) GetClients() map[string]*Conn {
return server.clients
}
// GetClientById 通过连接ID获取连接
func (server *Server) GetClientById(id int64) *Conn {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
for _, v := range server.clients {
if v.Auth.AppId == appid {
if v.CID == id {
return v
}
}
return nil
}
// GetAppIdClients 获取指定appid的所有连接
func (server *Server) GetAppIdClients(appid string) (list []*Conn) {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
for _, v := range server.clients {
if v.Auth != nil && v.Auth.AppId == appid {
list = append(list, v)
}
}
@@ -216,70 +161,63 @@ func (server *Server) getAppIdClients(appid string) (list []*ClientConn) {
}
// GetGroupClients 获取指定分组的所有连接
func (server *Server) GetGroupClients(group string) (list []*ClientConn) {
func (server *Server) GetGroupClients(group string) (list []*Conn) {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
for _, v := range server.clients {
if v.Auth.Group == group {
if v.Auth != nil && v.Auth.Group == group {
list = append(list, v)
}
}
return
}
// RegisterAuthRouter 注册授权路由
func (server *Server) RegisterAuthRouter(routers map[string]RouterHandler) {
server.mutex.Lock()
defer server.mutex.Unlock()
if server.authRouters == nil {
server.authRouters = make(map[string]RouterHandler)
}
for i, router := range routers {
_, ok := server.authRouters[i]
if ok {
server.Logger.Debugf(server.Ctx, "server authRouters duplicate registration:%v", i)
continue
}
server.authRouters[i] = router
}
// GetAppIdOnline 获取指定appid的在线数量
func (server *Server) GetAppIdOnline(appid string) int {
return len(server.GetAppIdClients(appid))
}
// RegisterCronRouter 注册任务路由
func (server *Server) RegisterCronRouter(routers map[string]RouterHandler) {
server.mutex.Lock()
defer server.mutex.Unlock()
if server.cronRouters == nil {
server.cronRouters = make(map[string]RouterHandler)
}
for i, router := range routers {
_, ok := server.cronRouters[i]
if ok {
server.Logger.Debugf(server.Ctx, "server cronRouters duplicate registration:%v", i)
continue
}
server.cronRouters[i] = router
}
// GetAllOnline 获取所有在线数量
func (server *Server) GetAllOnline() int {
return len(server.clients)
}
// RegisterQueueRouter 注册队列路由
func (server *Server) RegisterQueueRouter(routers map[string]RouterHandler) {
server.mutex.Lock()
defer server.mutex.Unlock()
// GetAuthOnline 获取所有已登录认证在线数量
func (server *Server) GetAuthOnline() int {
server.mutexConns.Lock()
defer server.mutexConns.Unlock()
if server.queueRouters == nil {
server.queueRouters = make(map[string]RouterHandler)
}
for i, router := range routers {
_, ok := server.queueRouters[i]
if ok {
server.Logger.Debugf(server.Ctx, "server queueRouters duplicate registration:%v", i)
continue
online := 0
for _, v := range server.clients {
if v.Auth != nil {
online++
}
server.queueRouters[i] = router
}
return online
}
// RegisterRouter 注册路由
func (server *Server) RegisterRouter(routers ...interface{}) {
err := server.msgParser.RegisterRouter(routers...)
if err != nil {
server.logger.Fatal(server.ctx, err)
}
return
}
// RegisterRPCRouter 注册RPC路由
func (server *Server) RegisterRPCRouter(routers ...interface{}) {
err := server.msgParser.RegisterRPCRouter(routers...)
if err != nil {
server.logger.Fatal(server.ctx, err)
}
return
}
// RegisterInterceptor 注册拦截器
func (server *Server) RegisterInterceptor(interceptors ...Interceptor) {
server.msgParser.RegisterInterceptor(interceptors...)
}
// Listen 监听服务
@@ -291,77 +229,68 @@ func (server *Server) Listen() (err error) {
// Close 关闭服务
func (server *Server) Close() {
if server.closeFlag {
if server.closeFlag.Val() {
return
}
server.closeFlag = true
server.closeFlag.Set(true)
server.stopCron()
server.mutexConns.Lock()
for _, client := range server.clients {
_ = client.Conn.Close()
client.Conn.Close()
}
server.clients = nil
server.mutexConns.Unlock()
if server.ln != nil {
_ = server.ln.Close()
server.ln.Close()
}
server.wgLn.Wait()
}
// IsClose 服务是否关闭
func (server *Server) IsClose() bool {
return server.closeFlag
return server.closeFlag.Val()
}
// Write 向指定客户端发送消息
func (server *Server) Write(conn *gtcp.Conn, data interface{}) (err error) {
if server.closeFlag {
return gerror.New("service is down")
}
msgType := reflect.TypeOf(data)
if msgType == nil || msgType.Kind() != reflect.Ptr {
return gerror.Newf("json message pointer required: %+v", data)
}
msg := &Message{Router: msgType.Elem().Name(), Data: data}
return SendPkg(conn, msg)
}
// Send 发送消息
func (server *Server) Send(ctx context.Context, client *ClientConn, data interface{}) (err error) {
MsgPkg(data, client.Auth, gctx.CtxId(ctx))
return server.Write(client.Conn, data)
}
// Reply 回复消息
func (server *Server) Reply(ctx context.Context, data interface{}) (err error) {
user := GetCtx(ctx)
if user == nil {
err = gerror.New("获取回复用户信息失败")
return
}
MsgPkg(data, user.Auth, user.TraceID)
return server.Write(user.Conn, data)
}
// RpcRequest 向指定客户端发送消息并等待响应结果
func (server *Server) RpcRequest(ctx context.Context, client *ClientConn, data interface{}) (res interface{}, err error) {
var (
traceID = MsgPkg(data, client.Auth, gctx.CtxId(ctx))
key = server.rpc.GetCallId(client.Conn, traceID)
// handleRoutineTask 处理协程任务
func (server *Server) handleRoutineTask(ctx context.Context, task func()) {
ctx, cancel := context.WithCancel(ctx)
err := server.taskGo.AddWithRecover(ctx,
func(ctx context.Context) {
task()
cancel()
},
func(ctx context.Context, err error) {
server.logger.Warningf(ctx, "routineTask exec err:%+v", err)
cancel()
},
)
if traceID == "" {
err = gerror.New("traceID is required")
if err != nil {
server.logger.Warningf(ctx, "routineTask add err:%+v", err)
}
}
// GetRoutes 获取所有路由
func (server *Server) GetRoutes() (routes []RouteHandler) {
if server.msgParser.routers == nil {
return
}
return server.rpc.Request(key, func() {
_ = server.Write(client.Conn, data)
})
for _, v := range server.msgParser.routers {
routes = append(routes, *v)
}
return
}
// Request 向指定客户端发送消息并等待响应结果
func (server *Server) Request(ctx context.Context, client *Conn, data interface{}) (interface{}, error) {
return client.Request(ctx, data)
}
// RequestScan 向指定客户端发送消息并等待响应结果将结果保存在response中
func (server *Server) RequestScan(ctx context.Context, client *Conn, data, response interface{}) error {
return client.RequestScan(ctx, data, response)
}

View File

@@ -10,7 +10,6 @@ import (
"fmt"
"github.com/gogf/gf/v2/os/gcron"
"github.com/gogf/gf/v2/os/gtime"
"hotgo/internal/consts"
)
// getCronKey 生成服务端定时任务名称
@@ -28,32 +27,35 @@ func (server *Server) stopCron() {
// startCron 启动定时任务
func (server *Server) startCron() {
// 心跳超时检查
if gcron.Search(server.getCronKey(consts.TCPCronHeartbeatVerify)) == nil {
_, _ = gcron.AddSingleton(server.Ctx, "@every 300s", func(ctx context.Context) {
if server.clients == nil {
if gcron.Search(server.getCronKey(CronHeartbeatVerify)) == nil {
gcron.AddSingleton(server.ctx, "@every 300s", func(ctx context.Context) {
if server == nil || server.clients == nil {
return
}
for _, client := range server.clients {
if client.heartbeat < gtime.Timestamp()-consts.TCPHeartbeatTimeout {
_ = client.Conn.Close()
server.Logger.Debugf(server.Ctx, "client heartbeat timeout, close conn. auth:%+v", client.Auth)
if client.Heartbeat < gtime.Timestamp()-HeartbeatTimeout {
client.Conn.Close()
server.logger.Debugf(server.ctx, "client heartbeat timeout, close conn. auth:%+v", client.Auth)
}
}
}, server.getCronKey(consts.TCPCronHeartbeatVerify))
}, server.getCronKey(CronHeartbeatVerify))
}
// 认证检查
if gcron.Search(server.getCronKey(consts.TCPCronAuthVerify)) == nil {
_, _ = gcron.AddSingleton(server.Ctx, "@every 300s", func(ctx context.Context) {
if server.clients == nil {
if gcron.Search(server.getCronKey(CronAuthVerify)) == nil {
gcron.AddSingleton(server.ctx, "@every 300s", func(ctx context.Context) {
if server == nil || server.clients == nil {
return
}
for _, client := range server.clients {
if client.Auth == nil {
continue
}
if client.Auth.EndAt.Before(gtime.Now()) {
_ = client.Conn.Close()
server.Logger.Debugf(server.Ctx, "client auth expired, close conn. auth:%+v", client.Auth)
client.Conn.Close()
server.logger.Debugf(server.ctx, "client auth expired, close conn. auth:%+v", client.Auth)
}
}
}, server.getCronKey(consts.TCPCronAuthVerify))
}, server.getCronKey(CronAuthVerify))
}
}

View File

@@ -1,166 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"context"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/consts"
"hotgo/internal/model/entity"
"hotgo/internal/model/input/msgin"
"hotgo/utility/convert"
)
// onServerLogin 处理客户端登录
func (server *Server) onServerLogin(ctx context.Context, args ...interface{}) {
var (
in = new(msgin.ServerLogin)
user = GetCtx(ctx)
res = new(msgin.ResponseServerLogin)
models *entity.SysServeLicense
)
if err := gconv.Scan(args[0], &in); err != nil {
server.Logger.Warningf(ctx, "onServerLogin message Scan failed:%+v, args:%+v", err, args)
return
}
err := g.Model("sys_serve_license").Ctx(ctx).
Where("appid = ?", in.AppId).
Scan(&models)
if err != nil {
res.Code = 1
res.Message = err.Error()
_ = server.Write(user.Conn, res)
return
}
if models == nil {
res.Code = 2
res.Message = "授权信息不存在"
_ = server.Write(user.Conn, res)
return
}
// 验证签名
if _, err = VerifySign(in, models.Appid, models.SecretKey); err != nil {
res.Code = 3
res.Message = "签名错误,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
if models.Status != consts.StatusEnabled {
res.Code = 4
res.Message = "授权已禁用,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
if models.Group != in.Group {
res.Code = 5
res.Message = "你登录的授权分组未得到授权,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
if models.EndAt.Before(gtime.Now()) {
res.Code = 6
res.Message = "授权已过期,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
allowedIps := convert.IpFilterStrategy(models.AllowedIps)
if _, ok := allowedIps["*"]; !ok {
ip := gstr.StrTillEx(user.Conn.RemoteAddr().String(), ":")
if _, ok2 := allowedIps[ip]; !ok2 {
res.Code = 7
res.Message = "IP(" + ip + ")未授权,请联系管理员"
_ = server.Write(user.Conn, res)
return
}
}
// 检查是否存在多地登录,如果连接超出上限,直接将所有已连接断开
clients := server.getAppIdClients(models.Appid)
online := len(clients) + 1
if online > models.OnlineLimit {
res2 := new(msgin.ResponseServerLogin)
res2.Code = 8
res2.Message = "授权登录端超出上限已进行记录。请立即终止操作。如有疑问请联系管理员"
for _, client := range clients {
_ = server.Write(client.Conn, res2)
_ = client.Conn.Close()
}
// 当前连接也踢掉
_ = server.Write(user.Conn, res2)
_ = user.Conn.Close()
return
}
server.mutexConns.Lock()
server.clients[user.Conn.RemoteAddr().String()] = &ClientConn{
Conn: user.Conn,
Auth: &AuthMeta{
Group: in.Group,
Name: in.Name,
AppId: in.AppId,
SecretKey: models.SecretKey,
EndAt: models.EndAt,
},
heartbeat: gtime.Timestamp(),
}
server.mutexConns.Unlock()
_, err = g.Model("sys_serve_license").Ctx(ctx).
Where("id = ?", models.Id).Data(g.Map{
"online": online,
"login_times": models.LoginTimes + 1,
"last_login_at": gtime.Now(),
"last_active_at": gtime.Now(),
"remote_addr": user.Conn.RemoteAddr().String(),
}).Update()
if err != nil {
server.Logger.Warningf(ctx, "onServerLogin Update err:%+v", err)
}
res.AppId = in.AppId
res.Code = consts.TCPMsgCodeSuccess
_ = server.Write(user.Conn, res)
}
// onServerHeartbeat 处理客户端心跳
func (server *Server) onServerHeartbeat(ctx context.Context, args ...interface{}) {
var (
in *msgin.ServerHeartbeat
res = new(msgin.ResponseServerHeartbeat)
)
if err := gconv.Scan(args[0], &in); err != nil {
server.Logger.Warningf(ctx, "onServerHeartbeat message Scan failed:%+v, args:%+v", err, args)
return
}
client := args[1].(*ClientConn)
client.heartbeat = gtime.Timestamp()
_, err := g.Model("sys_serve_license").Ctx(ctx).
Where("appid = ?", client.Auth.AppId).Data(g.Map{
"last_active_at": gtime.Now(),
}).Update()
if err != nil {
server.Logger.Warningf(ctx, "onServerHeartbeat Update err:%+v", err)
}
res.Code = consts.TCPMsgCodeSuccess
_ = server.Write(client.Conn, res)
}

View File

@@ -1,49 +0,0 @@
// Package tcp
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package tcp
import (
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/util/gconv"
"hotgo/internal/model/input/msgin"
)
type Sign interface {
SetSign(appId, secretKey string) *msgin.RpcMsg
SetTraceID(traceID string)
}
// PkgSign 打包签名
func PkgSign(data interface{}, appId, secretKey, traceID string) *msgin.RpcMsg {
if c, ok := data.(Sign); ok {
c.SetTraceID(traceID)
return c.SetSign(appId, secretKey)
}
return nil
}
// VerifySign 验证签名
func VerifySign(data interface{}, appId, secretKey string) (in *msgin.RpcMsg, err error) {
// 无密钥,无需签名
if secretKey == "" {
return
}
if err = gconv.Scan(data, &in); err != nil {
return
}
if appId != in.AppId {
err = gerror.New("appId invalid")
return
}
if in.Sign != in.GetSign(secretKey) {
err = gerror.New("sign invalid")
return
}
return
}

View File

@@ -0,0 +1,128 @@
package tcp_test
import (
"context"
"fmt"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/test/gtest"
"hotgo/internal/library/network/tcp"
"testing"
"time"
)
var T *testing.T // 声明一个全局的 *testing.T 变量
type TestMsgReq struct {
Name string `json:"name"`
}
type TestMsgRes struct {
tcp.ServerRes
}
type TestRPCMsgReq struct {
Name string `json:"name"`
}
type TestRPCMsgRes struct {
tcp.ServerRes
}
func onTestMsg(ctx context.Context, req *TestMsgReq) {
fmt.Printf("服务器收到消息 ==> onTestMsg:%+v\n", req)
conn := tcp.ConnFromCtx(ctx)
gtest.C(T, func(t *gtest.T) {
t.AssertNE(conn, nil)
})
res := new(TestMsgRes)
res.Message = fmt.Sprintf("你的名字:%v", req.Name)
conn.Send(ctx, res)
}
func onResponseTestMsg(ctx context.Context, req *TestMsgRes) {
fmt.Printf("客户端收到响应消息 ==> TestMsgRes:%+v\n", req)
err := req.GetError()
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}
func onTestRPCMsg(ctx context.Context, req *TestRPCMsgReq) (res *TestRPCMsgRes, err error) {
fmt.Printf("服务器收到消息 ==> onTestRPCMsg:%+v\n", req)
res = new(TestRPCMsgRes)
res.Message = fmt.Sprintf("你的名字:%v", req.Name)
return
}
func startTCPServer() {
serv := tcp.NewServer(&tcp.ServerConfig{
Name: "hotgo",
Addr: ":8002",
})
// 注册路由
serv.RegisterRouter(
onTestMsg,
)
// 注册RPC路由
serv.RegisterRPCRouter(
onTestRPCMsg,
)
// 服务监听
err := serv.Listen()
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}
// 一个基本的消息收发
func TestSendMsg(t *testing.T) {
T = t
go startTCPServer()
ctx := gctx.New()
client := tcp.NewClient(&tcp.ClientConfig{
Addr: "127.0.0.1:8002",
})
// 注册路由
client.RegisterRouter(
onResponseTestMsg,
)
go func() {
err := client.Start()
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}()
// 确保服务都启动完成
time.Sleep(time.Second * 1)
// 拿到客户端的连接
conn := client.Conn()
gtest.C(T, func(t *gtest.T) {
t.AssertNE(conn, nil)
})
// 向服务器发送tcp消息不会阻塞程序执行
err := conn.Send(ctx, &TestMsgReq{Name: "Tom"})
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
// 向服务器发送rpc消息会等待服务器响应结果直到拿到结果或响应超时才会继续
var res TestRPCMsgRes
if err = conn.RequestScan(ctx, &TestRPCMsgReq{Name: "Tony"}, &res); err != nil {
gtest.C(T, func(t *gtest.T) {
t.AssertNil(err)
})
}
fmt.Printf("客户端收到RPC消息响应 ==> TestRPCMsgRes:%+v\n", res)
time.Sleep(time.Second * 1)
}