From 96cd7a2b63fbc49458a202c2266136f742944326 Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Fri, 20 Oct 2017 16:36:43 +0800 Subject: [PATCH 1/2] Signed-off-by: arraykeys@gmail.com --- install_auto.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install_auto.sh b/install_auto.sh index 56eaa4b..b48a25d 100755 --- a/install_auto.sh +++ b/install_auto.sh @@ -6,7 +6,7 @@ fi mkdir /tmp/proxy cd /tmp/proxy wget https://github.com/reddec/monexec/releases/download/v0.1.1/monexec_0.1.1_linux_amd64.tar.gz -wget https://github.com/snail007/goproxy/releases/download/v3.4/proxy-linux-amd64.tar.gz +wget https://github.com/snail007/goproxy/releases/download/v3.3/proxy-linux-amd64.tar.gz # install monexec tar zxvf monexec_0.1.1_linux_amd64.tar.gz From 078acaa0e88f5fbc4a4d0276dd901878f78515b5 Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Mon, 23 Oct 2017 16:28:10 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=AE=8C=E5=96=84=E5=86=85=E7=BD=91?= =?UTF-8?q?=E7=A9=BF=E9=80=8F=E5=BF=83=E8=B7=B3=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: arraykeys@gmail.com --- services/args.go | 16 ++-- services/tunnel_bridge.go | 154 ++++++++++++++++++++---------------- services/tunnel_client.go | 67 +++++++--------- services/tunnel_server.go | 161 +++++++++++++++++++++++++++----------- utils/functions.go | 62 +++++++++++++++ utils/structs.go | 61 +++++++++++++++ 6 files changed, 365 insertions(+), 156 deletions(-) diff --git a/services/args.go b/services/args.go index 3b91f26..497de64 100644 --- a/services/args.go +++ b/services/args.go @@ -6,13 +6,14 @@ import "golang.org/x/crypto/ssh" // t := tcp.Flag("tcp-timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int() const ( - TYPE_TCP = "tcp" - TYPE_UDP = "udp" - TYPE_HTTP = "http" - TYPE_TLS = "tls" - CONN_CONTROL = uint8(1) - CONN_SERVER = uint8(2) - CONN_CLIENT = uint8(3) + TYPE_TCP = "tcp" + TYPE_UDP = "udp" + TYPE_HTTP = "http" + TYPE_TLS = "tls" + CONN_CLIENT_CONTROL = uint8(1) + CONN_SERVER_CONTROL = uint8(2) + CONN_SERVER = uint8(3) + CONN_CLIENT = uint8(4) ) type TunnelServerArgs struct { @@ -27,6 +28,7 @@ type TunnelServerArgs struct { Remote *string Timeout *int Route *[]string + Mgr *TunnelServerManager } type TunnelClientArgs struct { Parent *string diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index ab9bd57..853eec6 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -2,7 +2,6 @@ package services import ( "bufio" - "encoding/binary" "log" "net" "proxy/utils" @@ -11,13 +10,15 @@ import ( ) type ServerConn struct { - ClientLocalAddr string //tcp:2.2.22:333@ID - Conn *net.Conn + //ClientLocalAddr string //tcp:2.2.22:333@ID + Conn *net.Conn } type TunnelBridge struct { cfg TunnelBridgeArgs serverConns utils.ConcurrentMap clientControlConns utils.ConcurrentMap + cmServer utils.ConnManager + cmClient utils.ConnManager } func NewTunnelBridge() Service { @@ -25,6 +26,8 @@ func NewTunnelBridge() Service { cfg: TunnelBridgeArgs{}, serverConns: utils.NewConcurrentMap(), clientControlConns: utils.NewConcurrentMap(), + cmServer: utils.NewConnManager(), + cmClient: utils.NewConnManager(), } } @@ -52,73 +55,27 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { //log.Printf("connection from %s ", inConn.RemoteAddr()) reader := bufio.NewReader(inConn) + var err error var connType uint8 - err = binary.Read(reader, binary.LittleEndian, &connType) + err = utils.ReadPacket(reader, &connType) if err != nil { - utils.CloseConn(&inConn) + log.Printf("read error,ERR:%s", err) return } - //log.Printf("conn type %d", connType) - - var key, clientLocalAddr, ID string - var connTypeStrMap = map[uint8]string{CONN_SERVER: "server", CONN_CLIENT: "client", CONN_CONTROL: "client"} - var keyLength uint16 - err = binary.Read(reader, binary.LittleEndian, &keyLength) - if err != nil { - return - } - - _key := make([]byte, keyLength) - n, err := reader.Read(_key) - if err != nil { - return - } - if n != int(keyLength) { - return - } - key = string(_key) - - if connType != CONN_CONTROL { - var IDLength uint16 - err = binary.Read(reader, binary.LittleEndian, &IDLength) - if err != nil { - return - } - _id := make([]byte, IDLength) - n, err := reader.Read(_id) - if err != nil { - return - } - if n != int(IDLength) { - return - } - ID = string(_id) - - if connType == CONN_SERVER { - var addrLength uint16 - err = binary.Read(reader, binary.LittleEndian, &addrLength) - if err != nil { - return - } - _addr := make([]byte, addrLength) - n, err = reader.Read(_addr) - if err != nil { - return - } - if n != int(addrLength) { - return - } - clientLocalAddr = string(_addr) - } - } - log.Printf("connection from %s , key: %s , id: %s", connTypeStrMap[connType], key, ID) - switch connType { case CONN_SERVER: - addr := clientLocalAddr + "@" + ID + var key, ID, clientLocalAddr, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + packet := utils.BuildPacketData(ID, clientLocalAddr, serverID) + log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID) + + //addr := clientLocalAddr + "@" + ID s.serverConns.Set(ID, ServerConn{ - Conn: &inConn, - ClientLocalAddr: addr, + Conn: &inConn, }) for { item, ok := s.clientControlConns.Get(key) @@ -128,17 +85,26 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { continue } (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) - _, err := (*item.(*net.Conn)).Write([]byte(addr)) + _, err := (*item.(*net.Conn)).Write(packet) (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) if err != nil { log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) time.Sleep(time.Second * 3) continue } else { + s.cmServer.Add(serverID, ID, &inConn) break } } case CONN_CLIENT: + var key, ID, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID) + serverConnItem, ok := s.serverConns.Get(ID) if !ok { inConn.Close() @@ -147,15 +113,24 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { } serverConn := serverConnItem.(ServerConn).Conn utils.IoBind(*serverConn, inConn, func(err error) { - (*serverConn).Close() utils.CloseConn(&inConn) s.serverConns.Remove(ID) + s.cmClient.RemoveOne(key, ID) + s.cmServer.RemoveOne(serverID, ID) log.Printf("conn %s released", ID) }, func(i int, b bool) {}, 0) + s.cmClient.Add(key, ID, &inConn) log.Printf("conn %s created", ID) - case CONN_CONTROL: + case CONN_CLIENT_CONTROL: + var key string + err = utils.ReadPacketData(reader, &key) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client control connection, key: %s", key) if s.clientControlConns.Has(key) { item, _ := s.clientControlConns.Get(key) (*item.(*net.Conn)).Close() @@ -168,14 +143,59 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { _, err = inConn.Read(b) if err != nil { inConn.Close() - s.serverConns.Remove(ID) log.Printf("%s control conn from client released", key) + s.cmClient.Remove(key) break } else { //log.Printf("%s heartbeat from client", key) } } }() + case CONN_SERVER_CONTROL: + var serverID string + err = utils.ReadPacketData(reader, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("server control connection, id: %s", serverID) + writeDie := make(chan bool) + readDie := make(chan bool) + go func() { + for { + inConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) + _, err = inConn.Write([]byte{0x00}) + inConn.SetWriteDeadline(time.Time{}) + if err != nil { + log.Printf("control connection write err %s", err) + break + } + time.Sleep(time.Second * 3) + } + close(writeDie) + }() + go func() { + for { + signal := make([]byte, 1) + inConn.SetReadDeadline(time.Now().Add(time.Second * 10)) + _, err := inConn.Read(signal) + inConn.SetReadDeadline(time.Time{}) + if err != nil { + log.Printf("control connection read err: %s", err) + break + } else { + // log.Printf("heartbeat from server ,id:%s", ID) + } + } + close(readDie) + }() + select { + case <-readDie: + case <-writeDie: + } + utils.CloseConn(&inConn) + s.cmServer.Remove(serverID) + log.Printf("server control conn %s released", serverID) } }) if err != nil { diff --git a/services/tunnel_client.go b/services/tunnel_client.go index a77ad04..b48e507 100644 --- a/services/tunnel_client.go +++ b/services/tunnel_client.go @@ -1,25 +1,24 @@ package services import ( - "bytes" "crypto/tls" - "encoding/binary" "fmt" "io" "log" "net" "proxy/utils" - "strings" "time" ) type TunnelClient struct { cfg TunnelClientArgs + cm utils.ConnManager } func NewTunnelClient() Service { return &TunnelClient{ cfg: TunnelClientArgs{}, + cm: utils.NewConnManager(), } } @@ -37,14 +36,20 @@ func (s *TunnelClient) CheckArgs() { s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) } func (s *TunnelClient) StopService() { + s.cm.RemoveAll() } func (s *TunnelClient) Start(args interface{}) (err error) { s.cfg = args.(TunnelClientArgs) s.CheckArgs() s.InitService() log.Printf("proxy on tunnel client mode") + var ctrlConn net.Conn for { - ctrlConn, err := s.GetInConn(CONN_CONTROL, "") + //close all conn + s.cm.Remove(*s.cfg.Key) + utils.CloseConn(&ctrlConn) + + ctrlConn, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key) if err != nil { log.Printf("control connection err: %s, retrying...", err) time.Sleep(time.Second * 3) @@ -53,6 +58,9 @@ func (s *TunnelClient) Start(args interface{}) (err error) { } go func() { for { + if ctrlConn == nil { + break + } ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) _, err = ctrlConn.Write([]byte{0x00}) ctrlConn.SetWriteDeadline(time.Time{}) @@ -65,23 +73,20 @@ func (s *TunnelClient) Start(args interface{}) (err error) { } }() for { - signal := make([]byte, 50) - n, err := ctrlConn.Read(signal) + var ID, clientLocalAddr, serverID string + err = utils.ReadPacketData(ctrlConn, &ID, &clientLocalAddr, &serverID) if err != nil { utils.CloseConn(&ctrlConn) log.Printf("read connection signal err: %s, retrying...", err) break } - addr := string(signal[:n]) - log.Printf("signal revecived:%s", addr) - protocol := addr[:3] - atIndex := strings.Index(addr, "@") - ID := addr[atIndex+1:] - localAddr := addr[4:atIndex] + log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr) + protocol := clientLocalAddr[:3] + localAddr := clientLocalAddr[4:] if protocol == "udp" { - go s.ServeUDP(localAddr, ID) + go s.ServeUDP(localAddr, ID, serverID) } else { - go s.ServeConn(localAddr, ID) + go s.ServeConn(localAddr, ID, serverID) } } } @@ -89,25 +94,13 @@ func (s *TunnelClient) Start(args interface{}) (err error) { func (s *TunnelClient) Clean() { s.StopService() } -func (s *TunnelClient) GetInConn(typ uint8, ID string) (outConn net.Conn, err error) { +func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, err error) { outConn, err = s.GetConn() if err != nil { err = fmt.Errorf("connection err: %s", err) return } - keyBytes := []byte(*s.cfg.Key) - keyLength := uint16(len(keyBytes)) - pkg := new(bytes.Buffer) - binary.Write(pkg, binary.LittleEndian, typ) - binary.Write(pkg, binary.LittleEndian, keyLength) - binary.Write(pkg, binary.LittleEndian, keyBytes) - if ID != "" { - IDBytes := []byte(ID) - IDLength := uint16(len(IDBytes)) - binary.Write(pkg, binary.LittleEndian, IDLength) - binary.Write(pkg, binary.LittleEndian, IDBytes) - } - _, err = outConn.Write(pkg.Bytes()) + _, err = outConn.Write(utils.BuildPacket(typ, data...)) if err != nil { err = fmt.Errorf("write connection data err: %s ,retrying...", err) utils.CloseConn(&outConn) @@ -123,12 +116,13 @@ func (s *TunnelClient) GetConn() (conn net.Conn, err error) { } return } -func (s *TunnelClient) ServeUDP(localAddr, ID string) { +func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) { var inConn net.Conn var err error // for { for { - inConn, err = s.GetInConn(CONN_CLIENT, ID) + s.cm.RemoveOne(*s.cfg.Key, ID) + inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) if err != nil { utils.CloseConn(&inConn) log.Printf("connection err: %s, retrying...", err) @@ -138,13 +132,10 @@ func (s *TunnelClient) ServeUDP(localAddr, ID string) { break } } + s.cm.Add(*s.cfg.Key, ID, &inConn) log.Printf("conn %s created", ID) - // hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) { - // log.Printf("hw err %s", err) - // hw.Close() - // }) + for { - // srcAddr, body, err := utils.ReadUDPPacket(&hw) srcAddr, body, err := utils.ReadUDPPacket(inConn) if err == io.EOF || err == io.ErrUnexpectedEOF { log.Printf("connection %s released", ID) @@ -197,11 +188,11 @@ func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr str } //log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs) } -func (s *TunnelClient) ServeConn(localAddr, ID string) { +func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) { var inConn, outConn net.Conn var err error for { - inConn, err = s.GetInConn(CONN_CLIENT, ID) + inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) if err != nil { utils.CloseConn(&inConn) log.Printf("connection err: %s, retrying...", err) @@ -236,6 +227,8 @@ func (s *TunnelClient) ServeConn(localAddr, ID string) { log.Printf("conn %s released", ID) utils.CloseConn(&inConn) utils.CloseConn(&outConn) + s.cm.RemoveOne(*s.cfg.Key, ID) }, func(i int, b bool) {}, 0) + s.cm.Add(*s.cfg.Key, ID, &inConn) log.Printf("conn %s created", ID) } diff --git a/services/tunnel_server.go b/services/tunnel_server.go index f581a86..aa3c66f 100644 --- a/services/tunnel_server.go +++ b/services/tunnel_server.go @@ -1,9 +1,7 @@ package services import ( - "bytes" "crypto/tls" - "encoding/binary" "fmt" "io" "log" @@ -22,24 +20,33 @@ type TunnelServer struct { } type TunnelServerManager struct { - cfg TunnelServerArgs - udpChn chan UDPItem - sc utils.ServerChannel + cfg TunnelServerArgs + udpChn chan UDPItem + sc utils.ServerChannel + serverID string + cm utils.ConnManager } func NewTunnelServerManager() Service { return &TunnelServerManager{ - cfg: TunnelServerArgs{}, - udpChn: make(chan UDPItem, 50000), + cfg: TunnelServerArgs{}, + udpChn: make(chan UDPItem, 50000), + serverID: utils.Uniqueid(), + cm: utils.NewConnManager(), } } func (s *TunnelServerManager) Start(args interface{}) (err error) { s.cfg = args.(TunnelServerArgs) + s.CheckArgs() if *s.cfg.Parent != "" { log.Printf("use tls parent %s", *s.cfg.Parent) } else { log.Fatalf("parent required") } + + s.InitService() + + log.Printf("server id: %s", s.serverID) //log.Printf("route:%v", *s.cfg.Route) for _, _info := range *s.cfg.Route { IsUDP := *s.cfg.IsUDP @@ -71,6 +78,7 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) { Remote: &remote, Key: &KEY, Timeout: s.cfg.Timeout, + Mgr: s, }) if err != nil { @@ -80,7 +88,95 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) { return } func (s *TunnelServerManager) Clean() { - + s.StopService() +} +func (s *TunnelServerManager) StopService() { + s.cm.RemoveAll() +} +func (s *TunnelServerManager) CheckArgs() { + if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { + log.Fatalf("cert and key file required") + } + s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) +} +func (s *TunnelServerManager) InitService() { + s.InitControlDeamon() +} +func (s *TunnelServerManager) InitControlDeamon() { + go func() { + var ctrlConn net.Conn + var ID string + for { + //close all connection + s.cm.Remove(ID) + utils.CloseConn(&ctrlConn) + ctrlConn, ID, err := s.GetOutConn(CONN_SERVER_CONTROL) + if err != nil { + log.Printf("control connection err: %s, retrying...", err) + time.Sleep(time.Second * 3) + utils.CloseConn(&ctrlConn) + continue + } + log.Printf("control connection created,id:%s", ID) + writeDie := make(chan bool) + readDie := make(chan bool) + go func() { + for { + ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) + _, err = ctrlConn.Write([]byte{0x00}) + ctrlConn.SetWriteDeadline(time.Time{}) + if err != nil { + log.Printf("control connection write err %s", err) + break + } + time.Sleep(time.Second * 3) + } + close(writeDie) + }() + go func() { + for { + signal := make([]byte, 1) + ctrlConn.SetReadDeadline(time.Now().Add(time.Second * 10)) + _, err := ctrlConn.Read(signal) + ctrlConn.SetReadDeadline(time.Time{}) + if err != nil { + log.Printf("control connection read err: %s", err) + break + } else { + // log.Printf("heartbeat from bridge") + } + } + close(readDie) + }() + select { + case <-readDie: + case <-writeDie: + } + } + }() +} +func (s *TunnelServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) { + outConn, err = s.GetConn() + if err != nil { + log.Printf("connection err: %s", err) + return + } + ID = s.serverID + _, err = outConn.Write(utils.BuildPacket(typ, s.serverID)) + if err != nil { + log.Printf("write connection data err: %s ,retrying...", err) + utils.CloseConn(&outConn) + return + } + return +} +func (s *TunnelServerManager) GetConn() (conn net.Conn, err error) { + var _conn tls.Conn + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + if err == nil { + conn = net.Conn(&_conn) + } + return } func NewTunnelServer() Service { return &TunnelServer{ @@ -102,13 +198,8 @@ func (s *TunnelServer) CheckArgs() { if *s.cfg.Remote == "" { log.Fatalf("remote required") } - if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { - log.Fatalf("cert and key file required") - } - s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) -} -func (s *TunnelServer) StopService() { } + func (s *TunnelServer) Start(args interface{}) (err error) { s.cfg = args.(TunnelServerArgs) s.CheckArgs() @@ -138,7 +229,7 @@ func (s *TunnelServer) Start(args interface{}) (err error) { var outConn net.Conn var ID string for { - outConn, ID, err = s.GetOutConn("") + outConn, ID, err = s.GetOutConn(CONN_SERVER) if err != nil { utils.CloseConn(&outConn) log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) @@ -148,17 +239,14 @@ func (s *TunnelServer) Start(args interface{}) (err error) { break } } - // hb := utils.NewHeartbeatReadWriter(&outConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { - // log.Printf("%s conn %s to bridge released", *s.cfg.Key, ID) - // hb.Close() - // }) - // utils.IoBind(inConn, &hb, func(err error) { utils.IoBind(inConn, outConn, func(err error) { utils.CloseConn(&outConn) utils.CloseConn(&inConn) + s.cfg.Mgr.cm.RemoveOne(s.cfg.Mgr.serverID, ID) log.Printf("%s conn %s released", *s.cfg.Key, ID) }, func(i int, b bool) {}, 0) - + //add conn + s.cfg.Mgr.cm.Add(s.cfg.Mgr.serverID, ID, &inConn) log.Printf("%s conn %s created", *s.cfg.Key, ID) }) if err != nil { @@ -169,37 +257,20 @@ func (s *TunnelServer) Start(args interface{}) (err error) { return } func (s *TunnelServer) Clean() { - s.StopService() + } -func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, ID string, err error) { +func (s *TunnelServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) { outConn, err = s.GetConn() if err != nil { log.Printf("connection err: %s", err) return } - keyBytes := []byte(*s.cfg.Key) - keyLength := uint16(len(keyBytes)) - ID = utils.Uniqueid() - IDBytes := []byte(ID) - if id != "" { - ID = id - IDBytes = []byte(id) - } - IDLength := uint16(len(IDBytes)) - remoteAddr := []byte("tcp:" + *s.cfg.Remote) + remoteAddr := "tcp:" + *s.cfg.Remote if *s.cfg.IsUDP { - remoteAddr = []byte("udp:" + *s.cfg.Remote) + remoteAddr = "udp:" + *s.cfg.Remote } - remoteAddrLength := uint16(len(remoteAddr)) - pkg := new(bytes.Buffer) - binary.Write(pkg, binary.LittleEndian, CONN_SERVER) - binary.Write(pkg, binary.LittleEndian, keyLength) - binary.Write(pkg, binary.LittleEndian, keyBytes) - binary.Write(pkg, binary.LittleEndian, IDLength) - binary.Write(pkg, binary.LittleEndian, IDBytes) - binary.Write(pkg, binary.LittleEndian, remoteAddrLength) - binary.Write(pkg, binary.LittleEndian, remoteAddr) - _, err = outConn.Write(pkg.Bytes()) + ID = utils.Uniqueid() + _, err = outConn.Write(utils.BuildPacket(typ, *s.cfg.Key, ID, remoteAddr, s.cfg.Mgr.serverID)) if err != nil { log.Printf("write connection data err: %s ,retrying...", err) utils.CloseConn(&outConn) @@ -232,7 +303,7 @@ func (s *TunnelServer) UDPConnDeamon() { RETRY: if outConn == nil { for { - outConn, ID, err = s.GetOutConn("") + outConn, ID, err = s.GetOutConn(CONN_SERVER) if err != nil { // cmdChn <- true outConn = nil diff --git a/utils/functions.go b/utils/functions.go index 0382c1f..3bd131b 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -315,6 +315,68 @@ func Uniqueid() string { s := fmt.Sprintf("%d", src.Int63()) return s[len(s)-5:len(s)-1] + fmt.Sprintf("%d", uint64(time.Now().UnixNano()))[8:] } +func ReadData(r io.Reader) (data string, err error) { + var len uint16 + err = binary.Read(r, binary.LittleEndian, &len) + if err != nil { + return + } + var n int + _data := make([]byte, len) + n, err = r.Read(_data) + if err != nil { + return + } + if n != int(len) { + err = fmt.Errorf("error data len") + return + } + data = string(_data) + return +} +func ReadPacketData(r io.Reader, data ...*string) (err error) { + for _, d := range data { + *d, err = ReadData(r) + if err != nil { + return + } + } + return +} +func ReadPacket(r io.Reader, typ *uint8, data ...*string) (err error) { + var connType uint8 + err = binary.Read(r, binary.LittleEndian, &connType) + if err != nil { + return + } + *typ = connType + for _, d := range data { + *d, err = ReadData(r) + if err != nil { + return + } + } + return +} +func BuildPacket(typ uint8, data ...string) []byte { + pkg := new(bytes.Buffer) + binary.Write(pkg, binary.LittleEndian, typ) + for _, d := range data { + bytes := []byte(d) + binary.Write(pkg, binary.LittleEndian, uint16(len(bytes))) + binary.Write(pkg, binary.LittleEndian, bytes) + } + return pkg.Bytes() +} +func BuildPacketData(data ...string) []byte { + pkg := new(bytes.Buffer) + for _, d := range data { + bytes := []byte(d) + binary.Write(pkg, binary.LittleEndian, uint16(len(bytes))) + binary.Write(pkg, binary.LittleEndian, bytes) + } + return pkg.Bytes() +} func SubStr(str string, start, end int) string { if len(str) == 0 { return "" diff --git a/utils/structs.go b/utils/structs.go index 56abac8..858ae7b 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -617,3 +617,64 @@ func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) { } return } + +type ConnManager struct { + pool ConcurrentMap + l *sync.Mutex +} + +func NewConnManager() ConnManager { + cm := ConnManager{ + pool: NewConcurrentMap(), + l: &sync.Mutex{}, + } + return cm +} +func (cm *ConnManager) Add(key, ID string, conn *net.Conn) { + cm.pool.Upsert(key, nil, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} { + var conns ConcurrentMap + if !exist { + conns = NewConcurrentMap() + } else { + conns = valueInMap.(ConcurrentMap) + } + if conns.Has(ID) { + v, _ := conns.Get(ID) + (*v.(*net.Conn)).Close() + } + conns.Set(ID, conn) + log.Printf("%s conn added", key) + return conns + }) +} +func (cm *ConnManager) Remove(key string) { + var conns ConcurrentMap + if v, ok := cm.pool.Get(key); ok { + conns = v.(ConcurrentMap) + conns.IterCb(func(key string, v interface{}) { + CloseConn(v.(*net.Conn)) + }) + log.Printf("%s conns closed", key) + } + cm.pool.Remove(key) +} +func (cm *ConnManager) RemoveOne(key string, ID string) { + defer cm.l.Unlock() + cm.l.Lock() + var conns ConcurrentMap + if v, ok := cm.pool.Get(key); ok { + conns = v.(ConcurrentMap) + if conns.Has(ID) { + v, _ := conns.Get(ID) + (*v.(*net.Conn)).Close() + conns.Remove(ID) + cm.pool.Set(key, conns) + log.Printf("%s %s conn closed", key, ID) + } + } +} +func (cm *ConnManager) RemoveAll() { + for _, k := range cm.pool.Keys() { + cm.Remove(k) + } +}