From 6fb3457bd25a401940e7d87cd7671af0dfe6a25c Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Wed, 5 Sep 2018 11:43:53 +0800 Subject: [PATCH] optimise tunnel nat udp module --- services/mux/mux_client.go | 13 +- services/mux/mux_server.go | 21 ++-- services/tunnel/tunnel_bridge.go | 1 + services/tunnel/tunnel_client.go | 170 +++++++++++++++++++-------- services/tunnel/tunnel_server.go | 196 +++++++++++++++++-------------- utils/functions.go | 2 +- 6 files changed, 241 insertions(+), 162 deletions(-) diff --git a/services/mux/mux_client.go b/services/mux/mux_client.go index b34e77b..31f187d 100644 --- a/services/mux/mux_client.go +++ b/services/mux/mux_client.go @@ -36,6 +36,7 @@ type MuxClientArgs struct { } type ClientUDPConnItem struct { conn *smux.Stream + isActive bool touchtime int64 srcAddr *net.UDPAddr localAddr *net.UDPAddr @@ -292,11 +293,6 @@ func (s *MuxClient) ServeUDP(inConn *smux.Stream, localAddr, ID string) { } (*item).touchtime = time.Now().Unix() go (*item).udpConn.Write(body) - //_, err = (*item).udpConn.Write(body) - // if err != nil { - // s.log.Printf("send udp packet to %s fail, err : %s", item.localAddr, err) - // return - // } } } func (s *MuxClient) UDPRevecive(key, ID string) { @@ -334,16 +330,11 @@ func (s *MuxClient) UDPRevecive(key, ID string) { return } }() - // _, err = cui.conn.Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n])) - // if err != nil { - // s.log.Printf("send udp packet to bridge fail, err : %s", err) - // return - // } } }() } func (s *MuxClient) UDPGCDeamon() { - gctime := int64(60) + gctime := int64(30) go func() { if s.isStop { return diff --git a/services/mux/mux_server.go b/services/mux/mux_server.go index aab3c24..a38ceff 100644 --- a/services/mux/mux_server.go +++ b/services/mux/mux_server.go @@ -173,6 +173,7 @@ func NewMuxServer() services.Service { lockChn: make(chan bool, 1), sessions: mapx.NewConcurrentMap(), isStop: false, + udpConns: mapx.NewConcurrentMap(), } } @@ -181,7 +182,7 @@ type MuxUDPPacketItem struct { localAddr *net.UDPAddr srcAddr *net.UDPAddr } -type UDPConnItem struct { +type MuxUDPConnItem struct { conn *net.Conn touchtime int64 srcAddr *net.UDPAddr @@ -429,7 +430,7 @@ func (s *MuxServer) getParentConn() (conn net.Conn, err error) { return } func (s *MuxServer) UDPGCDeamon() { - gctime := int64(60) + gctime := int64(30) go func() { if s.isStop { return @@ -439,10 +440,10 @@ func (s *MuxServer) UDPGCDeamon() { <-timer.C gcKeys := []string{} s.udpConns.IterCb(func(key string, v interface{}) { - if time.Now().Unix()-v.(*UDPConnItem).touchtime > gctime { - (*(v.(*UDPConnItem).conn)).Close() + if time.Now().Unix()-v.(*MuxUDPConnItem).touchtime > gctime { + (*(v.(*MuxUDPConnItem).conn)).Close() gcKeys = append(gcKeys, key) - s.log.Printf("gc udp conn %s", v.(*UDPConnItem).connid) + s.log.Printf("gc udp conn %s", v.(*MuxUDPConnItem).connid) } }) for _, k := range gcKeys { @@ -454,7 +455,7 @@ func (s *MuxServer) UDPGCDeamon() { } func (s *MuxServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) { var ( - uc *UDPConnItem + uc *MuxUDPConnItem key = srcAddr.String() ID string err error @@ -475,7 +476,7 @@ func (s *MuxServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) { s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err) return } - uc = &UDPConnItem{ + uc = &MuxUDPConnItem{ conn: &outconn, srcAddr: srcAddr, localAddr: localAddr, @@ -484,7 +485,7 @@ func (s *MuxServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) { s.udpConns.Set(key, uc) s.UDPRevecive(key, ID) } else { - uc = v.(*UDPConnItem) + uc = v.(*MuxUDPConnItem) } go func() { defer func() { @@ -506,7 +507,7 @@ func (s *MuxServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) { func (s *MuxServer) UDPRevecive(key, ID string) { go func() { s.log.Printf("udp conn %s connected", ID) - var uc *UDPConnItem + var uc *MuxUDPConnItem defer func() { if uc != nil { (*uc.conn).Close() @@ -519,7 +520,7 @@ func (s *MuxServer) UDPRevecive(key, ID string) { s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID) return } - uc = v.(*UDPConnItem) + uc = v.(*MuxUDPConnItem) for { _, body, err := utils.ReadUDPPacket(*uc.conn) if err != nil { diff --git a/services/tunnel/tunnel_bridge.go b/services/tunnel/tunnel_bridge.go index 4484707..6c480f7 100644 --- a/services/tunnel/tunnel_bridge.go +++ b/services/tunnel/tunnel_bridge.go @@ -141,6 +141,7 @@ func (s *TunnelBridge) callback(inConn net.Conn) { var buf = make([]byte, 1024) n, _ := inConn.Read(buf) reader := bytes.NewReader(buf[:n]) + //reader := bufio.NewReader(inConn) var connType uint8 diff --git a/services/tunnel/tunnel_client.go b/services/tunnel/tunnel_client.go index 36182c2..323e843 100644 --- a/services/tunnel/tunnel_client.go +++ b/services/tunnel/tunnel_client.go @@ -7,6 +7,7 @@ import ( logger "log" "net" "os" + "strings" "time" "github.com/snail007/goproxy/services" @@ -28,6 +29,15 @@ type TunnelClientArgs struct { Timeout *int Jumper *string } +type ClientUDPConnItem struct { + conn *net.Conn + isActive bool + touchtime int64 + srcAddr *net.UDPAddr + localAddr *net.UDPAddr + udpConn *net.UDPConn + connid string +} type TunnelClient struct { cfg TunnelClientArgs ctrlConn net.Conn @@ -35,6 +45,7 @@ type TunnelClient struct { userConns mapx.ConcurrentMap log *logger.Logger jumper *jumper.Jumper + udpConns mapx.ConcurrentMap } func NewTunnelClient() services.Service { @@ -42,10 +53,12 @@ func NewTunnelClient() services.Service { cfg: TunnelClientArgs{}, userConns: mapx.NewConcurrentMap(), isStop: false, + udpConns: mapx.NewConcurrentMap(), } } func (s *TunnelClient) InitService() (err error) { + s.UDPGCDeamon() return } @@ -133,7 +146,7 @@ func (s *TunnelClient) Start(args interface{}, log *logger.Logger) (err error) { s.log.Printf("read connection signal err: %s, retrying...", err) break } - s.log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr) + //s.log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr) protocol := clientLocalAddr[:3] localAddr := clientLocalAddr[4:] if protocol == "udp" { @@ -240,62 +253,119 @@ func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) { } // s.cm.Add(*s.cfg.Key, ID, &inConn) s.log.Printf("conn %s created", ID) - + var item *ClientUDPConnItem + var body []byte + srcAddr := "" + defer func() { + if item != nil { + (*(*item).conn).Close() + (*item).udpConn.Close() + s.udpConns.Remove(srcAddr) + inConn.Close() + } + }() for { if s.isStop { return } - srcAddr, body, err := utils.ReadUDPPacket(inConn) - if err == io.EOF || err == io.ErrUnexpectedEOF { - s.log.Printf("connection %s released", ID) - utils.CloseConn(&inConn) - break - } else if err != nil { - s.log.Printf("udp packet revecived fail, err: %s", err) - } else { - //s.log.Printf("udp packet revecived:%s,%v", srcAddr, body) - go s.processUDPPacket(&inConn, srcAddr, localAddr, body) + srcAddr, body, err = utils.ReadUDPPacket(inConn) + if err != nil { + if strings.Contains(err.Error(), "n != int(") { + continue + } + if !utils.IsNetDeadlineErr(err) && err != io.EOF { + s.log.Printf("udp packet revecived from bridge fail, err: %s", err) + } + return } - + if v, ok := s.udpConns.Get(srcAddr); !ok { + _srcAddr, _ := net.ResolveUDPAddr("udp", srcAddr) + zeroAddr, _ := net.ResolveUDPAddr("udp", ":") + _localAddr, _ := net.ResolveUDPAddr("udp", localAddr) + c, err := net.DialUDP("udp", zeroAddr, _localAddr) + if err != nil { + s.log.Printf("create local udp conn fail, err : %s", err) + inConn.Close() + return + } + item = &ClientUDPConnItem{ + conn: &inConn, + srcAddr: _srcAddr, + localAddr: _localAddr, + udpConn: c, + connid: ID, + } + s.udpConns.Set(srcAddr, item) + s.UDPRevecive(srcAddr, ID) + } else { + item = v.(*ClientUDPConnItem) + } + (*item).touchtime = time.Now().Unix() + go (*item).udpConn.Write(body) } - // } } -func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr string, body []byte) { - dstAddr, err := net.ResolveUDPAddr("udp", localAddr) - if err != nil { - s.log.Printf("can't resolve address: %s", err) - utils.CloseConn(inConn) - return - } - clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} - conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) - if err != nil { - s.log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) - return - } - conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) - _, err = conn.Write(body) - if err != nil { - s.log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) - return - } - //s.log.Printf("send udp packet to %s success", dstAddr.String()) - buf := make([]byte, 1024) - length, _, err := conn.ReadFromUDP(buf) - if err != nil { - s.log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) - return - } - respBody := buf[0:length] - //s.log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody) - bs := utils.UDPPacket(srcAddr, respBody) - _, err = (*inConn).Write(bs) - if err != nil { - s.log.Printf("send udp response fail ,ERR:%s", err) - utils.CloseConn(inConn) - return - } - //s.log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs) +func (s *TunnelClient) UDPRevecive(key, ID string) { + go func() { + s.log.Printf("udp conn %s connected", ID) + v, ok := s.udpConns.Get(key) + if !ok { + s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID) + return + } + cui := v.(*ClientUDPConnItem) + buf := utils.LeakyBuffer.Get() + defer func() { + utils.LeakyBuffer.Put(buf) + (*cui.conn).Close() + cui.udpConn.Close() + s.udpConns.Remove(key) + s.log.Printf("udp conn %s released", ID) + }() + for { + n, err := cui.udpConn.Read(buf) + if err != nil { + if !utils.IsNetClosedErr(err) { + s.log.Printf("udp conn read udp packet fail , err: %s ", err) + } + return + } + cui.touchtime = time.Now().Unix() + go func() { + (*cui.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) + _, err = (*cui.conn).Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n])) + (*cui.conn).SetWriteDeadline(time.Time{}) + if err != nil { + cui.udpConn.Close() + return + } + }() + } + }() +} +func (s *TunnelClient) UDPGCDeamon() { + gctime := int64(30) + go func() { + if s.isStop { + return + } + timer := time.NewTicker(time.Second) + for { + <-timer.C + gcKeys := []string{} + s.udpConns.IterCb(func(key string, v interface{}) { + if time.Now().Unix()-v.(*ClientUDPConnItem).touchtime > gctime { + (*(v.(*ClientUDPConnItem).conn)).Close() + (v.(*ClientUDPConnItem).udpConn).Close() + gcKeys = append(gcKeys, key) + s.log.Printf("gc udp conn %s", v.(*ClientUDPConnItem).connid) + } + }) + for _, k := range gcKeys { + s.udpConns.Remove(k) + } + gcKeys = nil + } + }() } func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) { var inConn, outConn net.Conn diff --git a/services/tunnel/tunnel_server.go b/services/tunnel/tunnel_server.go index 28737fb..bc1262b 100644 --- a/services/tunnel/tunnel_server.go +++ b/services/tunnel/tunnel_server.go @@ -38,18 +38,17 @@ type TunnelServerArgs struct { } type TunnelServer struct { cfg TunnelServerArgs - udpChn chan UDPItem sc utils.ServerChannel isStop bool udpConn *net.Conn userConns mapx.ConcurrentMap log *logger.Logger jumper *jumper.Jumper + udpConns mapx.ConcurrentMap } type TunnelServerManager struct { cfg TunnelServerArgs - udpChn chan UDPItem serverID string servers []*services.Service log *logger.Logger @@ -58,7 +57,6 @@ type TunnelServerManager struct { func NewTunnelServerManager() services.Service { return &TunnelServerManager{ cfg: TunnelServerArgs{}, - udpChn: make(chan UDPItem, 50000), serverID: utils.Uniqueid(), servers: []*services.Service{}, } @@ -146,17 +144,25 @@ func (s *TunnelServerManager) InitService() (err error) { func NewTunnelServer() services.Service { return &TunnelServer{ cfg: TunnelServerArgs{}, - udpChn: make(chan UDPItem, 50000), isStop: false, userConns: mapx.NewConcurrentMap(), + udpConns: mapx.NewConcurrentMap(), } } -type UDPItem struct { +type TunnelUDPPacketItem struct { packet *[]byte localAddr *net.UDPAddr srcAddr *net.UDPAddr } +type TunnelUDPConnItem struct { + conn *net.Conn + isActive bool + touchtime int64 + srcAddr *net.UDPAddr + localAddr *net.UDPAddr + connid string +} func (s *TunnelServer) StopService() { defer func() { @@ -183,7 +189,7 @@ func (s *TunnelServer) StopService() { } } func (s *TunnelServer) InitService() (err error) { - s.UDPConnDeamon() + s.UDPGCDeamon() return } func (s *TunnelServer) CheckArgs() (err error) { @@ -217,11 +223,7 @@ func (s *TunnelServer) Start(args interface{}, log *logger.Logger) (err error) { s.sc = utils.NewServerChannel(host, p, s.log) if *s.cfg.IsUDP { err = s.sc.ListenUDP(func(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr) { - s.udpChn <- UDPItem{ - packet: &packet, - localAddr: localAddr, - srcAddr: srcAddr, - } + s.UDPSend(packet, localAddr, srcAddr) }) if err != nil { return @@ -348,89 +350,103 @@ func (s *TunnelServer) GetConn() (conn net.Conn, err error) { } return } -func (s *TunnelServer) UDPConnDeamon() { +func (s *TunnelServer) UDPGCDeamon() { + gctime := int64(30) go func() { - defer func() { - if err := recover(); err != nil { - s.log.Printf("udp conn deamon crashed with err : %s \nstack: %s", err, string(debug.Stack())) - } - }() - var outConn net.Conn - // var hb utils.HeartbeatReadWriter - var ID string - // var cmdChn = make(chan bool, 1000) - var err error + if s.isStop { + return + } + timer := time.NewTicker(time.Second) for { - if s.isStop { - return - } - item := <-s.udpChn - RETRY: - if s.isStop { - return - } - if outConn == nil { - for { - if s.isStop { - return - } - outConn, ID, err = s.GetOutConn(CONN_SERVER) - if err != nil { - // cmdChn <- true - outConn = nil - utils.CloseConn(&outConn) - s.log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) - time.Sleep(time.Second * 3) - continue - } else { - go func(outConn net.Conn, ID string) { - if s.udpConn != nil { - (*s.udpConn).Close() - } - s.udpConn = &outConn - for { - if s.isStop { - return - } - srcAddrFromConn, body, err := utils.ReadUDPPacket(outConn) - if err == io.EOF || err == io.ErrUnexpectedEOF { - s.log.Printf("UDP deamon connection %s exited", ID) - break - } - if err != nil { - s.log.Printf("parse revecived udp packet fail, err: %s ,%v", err, body) - continue - } - //s.log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn) - _srcAddr := strings.Split(srcAddrFromConn, ":") - if len(_srcAddr) != 2 { - s.log.Printf("parse revecived udp packet fail, addr error : %s", srcAddrFromConn) - continue - } - port, _ := strconv.Atoi(_srcAddr[1]) - dstAddr := &net.UDPAddr{IP: net.ParseIP(_srcAddr[0]), Port: port} - _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr) - if err != nil { - s.log.Printf("udp response to local %s fail,ERR:%s", srcAddrFromConn, err) - continue - } - //s.log.Printf("udp response to local %s success , %v", srcAddrFromConn, body) - } - }(outConn, ID) - break - } + <-timer.C + gcKeys := []string{} + s.udpConns.IterCb(func(key string, v interface{}) { + if time.Now().Unix()-v.(*TunnelUDPConnItem).touchtime > gctime { + (*(v.(*TunnelUDPConnItem).conn)).Close() + gcKeys = append(gcKeys, key) + s.log.Printf("gc udp conn %s", v.(*TunnelUDPConnItem).connid) } + }) + for _, k := range gcKeys { + s.udpConns.Remove(k) } - outConn.SetWriteDeadline(time.Now().Add(time.Second)) - _, err = outConn.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet)) - outConn.SetWriteDeadline(time.Time{}) - if err != nil { - utils.CloseConn(&outConn) - outConn = nil - s.log.Printf("write udp packet to %s fail ,flush err:%s ,retrying...", *s.cfg.Parent, err) - goto RETRY - } - //s.log.Printf("write packet %v", *item.packet) + gcKeys = nil + } + }() +} +func (s *TunnelServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) { + var ( + uc *TunnelUDPConnItem + key = srcAddr.String() + ID string + err error + outconn net.Conn + ) + v, ok := s.udpConns.Get(key) + if !ok { + outconn, ID, err = s.GetOutConn(CONN_SERVER) + if err != nil { + s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err) + return + } + uc = &TunnelUDPConnItem{ + conn: &outconn, + srcAddr: srcAddr, + localAddr: localAddr, + connid: ID, + } + s.udpConns.Set(key, uc) + s.UDPRevecive(key, ID) + } else { + uc = v.(*TunnelUDPConnItem) + } + go func() { + defer func() { + if e := recover(); e != nil { + (*uc.conn).Close() + s.udpConns.Remove(key) + s.log.Printf("udp sender crashed with error : %s", e) + } + }() + uc.touchtime = time.Now().Unix() + (*uc.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) + _, err = (*uc.conn).Write(utils.UDPPacket(srcAddr.String(), data)) + (*uc.conn).SetWriteDeadline(time.Time{}) + if err != nil { + s.log.Printf("write udp packet to %s fail ,flush err:%s ", *s.cfg.Parent, err) + } + }() +} +func (s *TunnelServer) UDPRevecive(key, ID string) { + go func() { + s.log.Printf("udp conn %s connected", ID) + var uc *TunnelUDPConnItem + defer func() { + if uc != nil { + (*uc.conn).Close() + } + s.udpConns.Remove(key) + s.log.Printf("udp conn %s released", ID) + }() + v, ok := s.udpConns.Get(key) + if !ok { + s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID) + return + } + uc = v.(*TunnelUDPConnItem) + for { + _, body, err := utils.ReadUDPPacket(*uc.conn) + if err != nil { + if strings.Contains(err.Error(), "n != int(") { + continue + } + if err != io.EOF { + s.log.Printf("udp conn read udp packet fail , err: %s ", err) + } + return + } + uc.touchtime = time.Now().Unix() + go s.sc.UDPListener.WriteToUDP(body, uc.srcAddr) } }() } diff --git a/utils/functions.go b/utils/functions.go index 00a4470..a00d334 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -577,7 +577,7 @@ func BuildPacket(packetType uint8, data ...string) []byte { binary.Write(pkg, binary.LittleEndian, packetType) for _, d := range data { bytes := []byte(d) - binary.Write(pkg, binary.LittleEndian, uint16(len(bytes))) + binary.Write(pkg, binary.LittleEndian, uint64(len(bytes))) binary.Write(pkg, binary.LittleEndian, bytes) } return pkg.Bytes()