Skip to content

Commit

Permalink
chore: Refactor transporter dialing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 committed Dec 19, 2023
1 parent 5249012 commit 1a8420c
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type RelayTransporter interface {
HandleUDPConn(uaddr *net.UDPAddr, local *net.UDPConn)

// TCP相关
dialRemote(remote *lb.Node) (net.Conn, error)
HandleTCPConn(c net.Conn, remote *lb.Node) error
GetRemote() *lb.Node
}
Expand Down
25 changes: 17 additions & 8 deletions internal/transporter/mtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"time"

"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/web"
"github.com/xtaci/smux"
Expand All @@ -17,17 +18,24 @@ type MTCP struct {
mtp *smuxTransporter
}

func (s *MTCP) HandleTCPConn(c net.Conn, remote *lb.Node) error {
defer c.Close()
func (s *MTCP) dialRemote(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
mwsc, err := s.mtp.Dial(context.TODO(), remote.Address)
mtcpc, err := s.mtp.Dial(context.TODO(), remote.Address)
if err != nil {
return nil, err
}
web.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(time.Since(t1).Milliseconds()))
return mtcpc, nil
}

func (s *MTCP) HandleTCPConn(c net.Conn, remote *lb.Node) error {
defer c.Close()
mctpc, err := s.dialRemote(remote)
if err != nil {
return err
}
defer mwsc.Close()
s.L.Infof("HandleTCPConn from:%s to:%s", c.LocalAddr(), remote.Address)
return transport(c, mwsc, remote.Label)
return transport(c, mctpc, remote.Label)
}

type MTCPServer struct {
Expand Down Expand Up @@ -126,15 +134,16 @@ func (s *MTCPServer) Close() error {
}

type MTCPClient struct {
l *zap.SugaredLogger
l *zap.SugaredLogger
dialer *net.Dialer
}

func NewMTCPClient(l *zap.SugaredLogger) *MTCPClient {
return &MTCPClient{l: l}
return &MTCPClient{l: l, dialer: &net.Dialer{Timeout: constant.DialTimeOut}}
}

func (c *MTCPClient) InitNewSession(ctx context.Context, addr string) (*smux.Session, error) {
rc, err := net.Dial("tcp", addr)
rc, err := c.dialer.Dial("tcp", addr)
if err != nil {
return nil, err
}
Expand Down
18 changes: 11 additions & 7 deletions internal/transporter/mwss.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@ type Mwss struct {
mtp *smuxTransporter
}

func (s *Mwss) HandleTCPConn(c net.Conn, remote *lb.Node) error {
defer c.Close()
func (s *Mwss) dialRemote(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
mwsc, err := s.mtp.Dial(context.TODO(), remote.Address+"/mwss/")
if err != nil {
return nil, err
}
web.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(time.Since(t1).Milliseconds()))
return mwsc, nil
}

func (s *Mwss) HandleTCPConn(c net.Conn, remote *lb.Node) error {
defer c.Close()
mwsc, err := s.dialRemote(remote)
if err != nil {
return err
}
Expand Down Expand Up @@ -146,11 +154,7 @@ type MWSSClient struct {
}

func NewMWSSClient(l *zap.SugaredLogger) *MWSSClient {
dialer := &ws.Dialer{
TLSConfig: mytls.DefaultTLSConfig,
Timeout: constant.DialTimeOut,
}

dialer := &ws.Dialer{TLSConfig: mytls.DefaultTLSConfig, Timeout: constant.DialTimeOut}
return &MWSSClient{
dialer: dialer,
l: l,
Expand Down
4 changes: 2 additions & 2 deletions internal/transporter/raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,13 @@ func (raw *Raw) GetRemote() *lb.Node {
}

func (raw *Raw) dialRemote(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
d := net.Dialer{Timeout: constant.DialTimeOut}
rc, err := d.Dial("tcp", remote.Address)
if err != nil {
return nil, err
}
web.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(time.Since(t1).Milliseconds()))
return rc, nil
}

Expand All @@ -130,12 +132,10 @@ func (raw *Raw) HandleTCPConn(c net.Conn, remote *lb.Node) error {
defer web.CurConnectionCount.WithLabelValues(remote.Label, web.METRIC_CONN_TYPE_TCP).Dec()

defer c.Close()
t1 := time.Now()
rc, err := raw.dialRemote(remote)
if err != nil {
return err
}
web.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(time.Since(t1).Milliseconds()))
raw.L.Infof("HandleTCPConn from %s to %s", c.LocalAddr(), remote.Address)
defer rc.Close()
return transport(c, rc, remote.Label)
Expand Down
2 changes: 2 additions & 0 deletions internal/transporter/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ type Ws struct {
}

func (s *Ws) dialRemote(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
d := ws.Dialer{Timeout: constant.DialTimeOut}
wsc, _, _, err := d.Dial(context.TODO(), remote.Address+"/ws/")
if err != nil {
return nil, err
}
web.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(time.Since(t1).Milliseconds()))
return wsc, nil
}

Expand Down
16 changes: 11 additions & 5 deletions internal/transporter/wss.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,23 @@ type Wss struct {
*Raw
}

func (s *Wss) HandleTCPConn(c net.Conn, remote *lb.Node) error {
defer c.Close()

d := ws.Dialer{TLSConfig: mytls.DefaultTLSConfig}
func (s *Wss) dialRemote(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
d := ws.Dialer{TLSConfig: mytls.DefaultTLSConfig}
wsc, _, _, err := d.Dial(context.TODO(), remote.Address+"/wss/")
if err != nil {
return nil, err
}
web.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(time.Since(t1).Milliseconds()))
return wsc, nil
}

func (s *Wss) HandleTCPConn(c net.Conn, remote *lb.Node) error {
defer c.Close()
wsc, err := s.dialRemote(remote)
if err != nil {
return err
}
defer wsc.Close()
s.L.Infof("HandleTCPConn from %s to %s", c.RemoteAddr(), remote.Address)
return transport(c, wsc, remote.Label)
}
Expand Down

0 comments on commit 1a8420c

Please sign in to comment.