diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index 53c5d67..3c045ea 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -12,7 +12,8 @@ import ( type ServerConn struct { ClientLocalAddr string //tcp:2.2.22:333@ID - Conn *net.Conn + // Conn *net.Conn + Conn *utils.HeartbeatReadWriter } type TunnelBridge struct { cfg TunnelBridgeArgs @@ -116,9 +117,13 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { switch connType { case CONN_SERVER: + hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { + log.Printf("%s conn %s from server released", key, ID) + s.serverConns.Remove(ID) + }) addr := clientLocalAddr + "@" + ID s.serverConns.Set(ID, ServerConn{ - Conn: &inConn, + Conn: &hb, ClientLocalAddr: addr, }) for { @@ -146,13 +151,14 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { } serverConn := serverConnItem.(ServerConn).Conn // hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) { - // log.Printf("hw err %s", err) + // log.Printf("%s conn %s from client released", key, ID) // hw.Close() // }) - // utils.IoBind(*serverConn, &hw, func(isSrcErr bool, err error) { - utils.IoBind(*serverConn, inConn, func(isSrcErr bool, err error) { - utils.CloseConn(serverConn) + utils.IoBind(serverConn, inConn, func(isSrcErr bool, err error) { + // utils.IoBind(serverConn, inConn, func(isSrcErr bool, err error) { + serverConn.Close() utils.CloseConn(&inConn) + // hw.Close() s.serverConns.Remove(ID) log.Printf("conn %s released", ID) }, func(i int, b bool) {}, 0) diff --git a/services/tunnel_client.go b/services/tunnel_client.go index 2f02185..2fbdd91 100644 --- a/services/tunnel_client.go +++ b/services/tunnel_client.go @@ -142,9 +142,13 @@ func (s *TunnelClient) ServeUDP(localAddr, ID string) { log.Printf("connection %s released", ID) utils.CloseConn(&inConn) break + } else if err != nil { + log.Printf("udp packet revecived fail, err: %s", err) + } else { + log.Printf("udp packet revecived:%s,%v", srcAddr, body) + go s.processUDPPacket(&inConn, srcAddr, localAddr, body) } - //log.Printf("udp packet revecived:%s,%v", srcAddr, body) - go s.processUDPPacket(&inConn, srcAddr, localAddr, body) + } // } } @@ -168,21 +172,22 @@ func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr str return } //log.Printf("send udp packet to %s success", dstAddr.String()) - buf := make([]byte, 512) - len, _, err := conn.ReadFromUDP(buf) + buf := make([]byte, 1024) + length, _, err := conn.ReadFromUDP(buf) if err != nil { log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) return } - respBody := buf[0:len] - //log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody) - _, err = (*inConn).Write(utils.UDPPacket(srcAddr, respBody)) + respBody := buf[0:length] + log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody) + bs := utils.UDPPacket(srcAddr, respBody) + _, err = (*inConn).Write(bs) if err != nil { log.Printf("send udp response fail ,ERR:%s", err) utils.CloseConn(inConn) return } - //log.Printf("send udp response success ,from:%s", dstAddr.String()) + log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs) } func (s *TunnelClient) ServeConn(localAddr, ID string) { var inConn, outConn net.Conn diff --git a/services/tunnel_server.go b/services/tunnel_server.go index ccbe6c6..9965fba 100644 --- a/services/tunnel_server.go +++ b/services/tunnel_server.go @@ -1,7 +1,6 @@ package services import ( - "bufio" "bytes" "crypto/tls" "encoding/binary" @@ -80,8 +79,9 @@ func (s *TunnelServer) Start(args interface{}) (err error) { } }() var outConn net.Conn + var ID string for { - outConn, err = s.GetOutConn("") + outConn, ID, err = s.GetOutConn("") if err != nil { utils.CloseConn(&outConn) log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) @@ -91,13 +91,18 @@ func (s *TunnelServer) Start(args interface{}) (err error) { break } } - utils.IoBind(inConn, outConn, func(isSrcErr bool, err error) { + hb := utils.NewHeartbeatReadWriter(&outConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { + log.Printf("%s conn %s to bridge released", *s.cfg.Key, ID) + hb.Close() + }) + utils.IoBind(inConn, &hb, func(isSrcErr bool, err error) { + //utils.IoBind(inConn, outConn, func(isSrcErr bool, err error) { utils.CloseConn(&outConn) utils.CloseConn(&inConn) - log.Printf("%s conn %s - %s - %s - %s released", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr()) + log.Printf("%s conn %s released", *s.cfg.Key, ID) }, func(i int, b bool) {}, 0) - log.Printf("%s conn %s - %s - %s - %s created", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr()) + log.Printf("%s conn %s created", *s.cfg.Key, ID) }) if err != nil { return @@ -109,7 +114,7 @@ func (s *TunnelServer) Start(args interface{}) (err error) { func (s *TunnelServer) Clean() { s.StopService() } -func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, err error) { +func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, ID string, err error) { outConn, err = s.GetConn() if err != nil { log.Printf("connection err: %s", err) @@ -117,8 +122,10 @@ func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, err error) { } keyBytes := []byte(*s.cfg.Key) keyLength := uint16(len(keyBytes)) - IDBytes := []byte(utils.NewUniqueID().String()) + ID = utils.NewUniqueID().String() + IDBytes := []byte(ID) if id != "" { + ID = id IDBytes = []byte(id) } IDLength := uint16(len(IDBytes)) @@ -159,6 +166,8 @@ func (s *TunnelServer) UDPConnDeamon() { } }() var outConn net.Conn + var hb utils.HeartbeatReadWriter + var ID string var cmdChn = make(chan bool, 1) var err error @@ -167,7 +176,7 @@ func (s *TunnelServer) UDPConnDeamon() { RETRY: if outConn == nil { for { - outConn, err = s.GetOutConn("") + outConn, ID, err = s.GetOutConn("") if err != nil { cmdChn <- true outConn = nil @@ -176,19 +185,23 @@ func (s *TunnelServer) UDPConnDeamon() { time.Sleep(time.Second * 3) continue } else { - go func(outConn net.Conn) { + hb = utils.NewHeartbeatReadWriter(&outConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { + log.Printf("%s conn %s to bridge released", *s.cfg.Key, ID) + hb.Close() + }) + go func(outConn net.Conn, hb utils.HeartbeatReadWriter, ID string) { go func() { <-cmdChn outConn.Close() }() for { - srcAddrFromConn, body, err := utils.ReadUDPPacket(bufio.NewReader(outConn)) + srcAddrFromConn, body, err := utils.ReadUDPPacket(&hb) if err == io.EOF || err == io.ErrUnexpectedEOF { - log.Printf("udp connection deamon exited, %s -> %s", outConn.LocalAddr(), outConn.RemoteAddr()) + log.Printf("UDP deamon connection %s exited", ID) break } if err != nil { - log.Printf("parse revecived udp packet fail, err: %s", err) + log.Printf("parse revecived udp packet fail, err: %s ,%v", err, body) continue } //log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn) @@ -204,25 +217,27 @@ func (s *TunnelServer) UDPConnDeamon() { log.Printf("udp response to local %s fail,ERR:%s", srcAddrFromConn, err) continue } - //log.Printf("udp response to local %s success", srcAddrFromConn) + //log.Printf("udp response to local %s success , %v", srcAddrFromConn, body) } - }(outConn) + }(outConn, hb, ID) break } } } outConn.SetWriteDeadline(time.Now().Add(time.Second)) - writer := bufio.NewWriter(outConn) - writer.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet)) - err := writer.Flush() + _, err = hb.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet)) + // writer := bufio.NewWriter(outConn) + // writer.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet)) + // err := writer.Flush() + outConn.SetWriteDeadline(time.Time{}) if err != nil { utils.CloseConn(&outConn) outConn = nil log.Printf("write udp packet to %s fail ,flush err:%s ,retrying...", *s.cfg.Parent, err) goto RETRY } - outConn.SetWriteDeadline(time.Time{}) - //log.Printf("write packet %v", *item.packet) + + log.Printf("write packet %v", *item.packet) } }() } diff --git a/utils/functions.go b/utils/functions.go index fcc80bb..4f9eff0 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -1,6 +1,7 @@ package utils import ( + "bufio" "bytes" "crypto/tls" "crypto/x509" @@ -264,6 +265,7 @@ func UDPPacket(srcAddr string, packet []byte) []byte { addrBytes := []byte(srcAddr) addrLength := uint16(len(addrBytes)) bodyLength := uint16(len(packet)) + //log.Printf("build packet : addr len %d, body len %d", addrLength, bodyLength) pkg := new(bytes.Buffer) binary.Write(pkg, binary.LittleEndian, addrLength) binary.Write(pkg, binary.LittleEndian, addrBytes) @@ -271,8 +273,8 @@ func UDPPacket(srcAddr string, packet []byte) []byte { binary.Write(pkg, binary.LittleEndian, packet) return pkg.Bytes() } -func ReadUDPPacket(reader io.Reader) (srcAddr string, packet []byte, err error) { - // reader := bufio.NewReader(_reader) +func ReadUDPPacket(_reader io.Reader) (srcAddr string, packet []byte, err error) { + reader := bufio.NewReader(_reader) var addrLength uint16 var bodyLength uint16 err = binary.Read(reader, binary.LittleEndian, &addrLength) @@ -285,12 +287,14 @@ func ReadUDPPacket(reader io.Reader) (srcAddr string, packet []byte, err error) return } if n != int(addrLength) { + err = fmt.Errorf("n != int(addrLength), %d,%d", n, addrLength) return } srcAddr = string(_srcAddr) err = binary.Read(reader, binary.LittleEndian, &bodyLength) if err != nil { + return } packet = make([]byte, bodyLength) @@ -299,6 +303,7 @@ func ReadUDPPacket(reader io.Reader) (srcAddr string, packet []byte, err error) return } if n != int(bodyLength) { + err = fmt.Errorf("n != int(bodyLength), %d,%d", n, bodyLength) return } return diff --git a/utils/structs.go b/utils/structs.go index 4175b6d..57303e7 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -469,22 +469,30 @@ type HeartbeatData struct { Error error } type HeartbeatReadWriter struct { - conn *net.Conn - rchn chan HeartbeatData + conn *net.Conn + // rchn chan HeartbeatData l *sync.Mutex dur int errHandler func(err error, hb *HeartbeatReadWriter) once *sync.Once + datachn chan byte + // rbuf bytes.Buffer + // signal chan bool + rerrchn chan error } func NewHeartbeatReadWriter(conn *net.Conn, dur int, fn func(err error, hb *HeartbeatReadWriter)) (hrw HeartbeatReadWriter) { hrw = HeartbeatReadWriter{ - conn: conn, - l: &sync.Mutex{}, - dur: dur, - rchn: make(chan HeartbeatData, 10000), + conn: conn, + l: &sync.Mutex{}, + dur: dur, + // rchn: make(chan HeartbeatData, 10000), + // signal: make(chan bool, 1), errHandler: fn, + datachn: make(chan byte, 4*1024), once: &sync.Once{}, + rerrchn: make(chan error, 1), + // rbuf: bytes.Buffer{}, } hrw.heartbeat() hrw.reader() @@ -499,15 +507,25 @@ func (rw *HeartbeatReadWriter) reader() { //log.Printf("heartbeat read started") for { n, data, err := rw.read() - log.Printf("n:%d , data:%s ,err:%s", n, string(data), err) - if n >= 0 { - rw.rchn <- HeartbeatData{ - Data: data, - Error: err, - N: n, + if n == -1 { + continue + } + //log.Printf("n:%d , data:%s ,err:%s", n, string(data), err) + if err == nil { + //fmt.Printf("write data %s\n", string(data)) + for _, b := range data { + rw.datachn <- b } } if err != nil { + //log.Printf("heartbeat reader err: %s", err) + select { + case rw.rerrchn <- err: + default: + } + rw.once.Do(func() { + rw.errHandler(err, rw) + }) break } } @@ -515,13 +533,6 @@ func (rw *HeartbeatReadWriter) reader() { }() } func (rw *HeartbeatReadWriter) read() (n int, data []byte, err error) { - defer func() { - if err != nil { - rw.once.Do(func() { - rw.errHandler(err, rw) - }) - } - }() var typ uint8 err = binary.Read((*rw.conn), binary.LittleEndian, &typ) if err != nil { @@ -534,13 +545,18 @@ func (rw *HeartbeatReadWriter) read() (n int, data []byte, err error) { } var dataLength uint32 binary.Read((*rw.conn), binary.LittleEndian, &dataLength) - data = make([]byte, dataLength) + _data := make([]byte, dataLength) // log.Printf("dataLength:%d , data:%s", dataLength, string(data)) - n, err = (*rw.conn).Read(data) + n, err = (*rw.conn).Read(_data) //log.Printf("n:%d , data:%s ,err:%s", n, string(data), err) if err != nil { return } + if uint32(n) != dataLength { + err = fmt.Errorf("read short data body") + return + } + data = _data[:n] return } func (rw *HeartbeatReadWriter) heartbeat() { @@ -555,13 +571,11 @@ func (rw *HeartbeatReadWriter) heartbeat() { _, err := (*rw.conn).Write([]byte{0}) rw.l.Unlock() if err != nil { - if rw.errHandler != nil { - //log.Printf("heartbeat err: %s", err) - rw.once.Do(func() { - rw.errHandler(err, rw) - }) - break - } + //log.Printf("heartbeat err: %s", err) + rw.once.Do(func() { + rw.errHandler(err, rw) + }) + break } else { // log.Printf("heartbeat send ok") } @@ -571,12 +585,18 @@ func (rw *HeartbeatReadWriter) heartbeat() { }() } func (rw *HeartbeatReadWriter) Read(p []byte) (n int, err error) { - item := <-rw.rchn - //log.Println(item) - if item.N > 0 { - copy(p, item.Data) + data := make([]byte, cap(p)) + for i := 0; i < cap(p); i++ { + data[i] = <-rw.datachn + n++ + //fmt.Printf("read %d %v\n", i, data[:n]) + if len(rw.datachn) == 0 { + n = i + 1 + copy(p, data[:n]) + return + } } - return item.N, item.Error + return } func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) { defer rw.l.Unlock() @@ -585,6 +605,10 @@ func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) { binary.Write(pkg, binary.LittleEndian, uint8(1)) binary.Write(pkg, binary.LittleEndian, uint32(len(p))) binary.Write(pkg, binary.LittleEndian, p) - n, err = (*rw.conn).Write(pkg.Bytes()) + bs := pkg.Bytes() + n, err = (*rw.conn).Write(bs) + if err == nil { + n = len(p) + } return }