From 0247c4701d6dd9bcb88e444b6f0ea3249706274e Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Thu, 30 Nov 2017 18:43:31 +0800 Subject: [PATCH] Signed-off-by: arraykeys@gmail.com --- services/mux_bridge.go | 263 +++++++++++++++-------------------------- services/mux_server.go | 89 ++------------ 2 files changed, 103 insertions(+), 249 deletions(-) diff --git a/services/mux_bridge.go b/services/mux_bridge.go index 11ab8dc..32c4b88 100644 --- a/services/mux_bridge.go +++ b/services/mux_bridge.go @@ -7,6 +7,8 @@ import ( "proxy/utils" "strconv" "time" + + "github.com/xtaci/smux" ) type MuxServerConn struct { @@ -17,8 +19,6 @@ type MuxBridge struct { cfg MuxBridgeArgs serverConns utils.ConcurrentMap clientControlConns utils.ConcurrentMap - // cmServer utils.ConnManager - // cmClient utils.ConnManager } func NewMuxBridge() Service { @@ -26,8 +26,6 @@ func NewMuxBridge() Service { cfg: MuxBridgeArgs{}, serverConns: utils.NewConcurrentMap(), clientControlConns: utils.NewConcurrentMap(), - // cmServer: utils.NewConnManager(), - // cmClient: utils.NewConnManager(), } } @@ -52,8 +50,6 @@ func (s *MuxBridge) Start(args interface{}) (err error) { sc := utils.NewServerChannel(host, p) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) { - //log.Printf("connection from %s ", inConn.RemoteAddr()) - reader := bufio.NewReader(inConn) var err error var connType uint8 @@ -64,181 +60,112 @@ func (s *MuxBridge) Start(args interface{}) (err error) { } switch connType { case CONN_SERVER: - var key, ID, clientLocalAddr, serverID string - err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID) + session, err := smux.Server(inConn, nil) if err != nil { - log.Printf("read error,ERR:%s", err) + utils.CloseConn(&inConn) + log.Printf("server underlayer connection error,ERR:%s", err) return } - packet := utils.BuildPacketData(ID, clientLocalAddr, serverID) - log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID) - - //addr := clientLocalAddr + "@" + ID - s.serverConns.Set(ID, MuxServerConn{ - Conn: &inConn, - }) - for { - item, ok := s.clientControlConns.Get(key) - if !ok { - log.Printf("client %s control conn not exists", key) - time.Sleep(time.Second * 3) - continue - } - (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) - _, err := (*item.(*net.Conn)).Write(packet) - (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) - if err != nil { - log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) - time.Sleep(time.Second * 3) - continue - } else { - // s.cmServer.Add(serverID, ID, &inConn) - break - } - } - case CONN_CLIENT: - var key, ID, serverID string - err = utils.ReadPacketData(reader, &key, &ID, &serverID) + conn, err := session.AcceptStream() if err != nil { - log.Printf("read error,ERR:%s", err) + session.Close() + utils.CloseConn(&inConn) return } - log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID) - - serverConnItem, ok := s.serverConns.Get(ID) - if !ok { - inConn.Close() - log.Printf("server conn %s exists", ID) - return - } - serverConn := serverConnItem.(MuxServerConn).Conn - utils.IoBind(*serverConn, inConn, func(err interface{}) { - s.serverConns.Remove(ID) - // s.cmClient.RemoveOne(key, ID) - // s.cmServer.RemoveOne(serverID, ID) - log.Printf("conn %s released", ID) - }) - // s.cmClient.Add(key, ID, &inConn) - log.Printf("conn %s created", ID) - - case CONN_CLIENT_CONTROL: - var key string - err = utils.ReadPacketData(reader, &key) - if err != nil { - log.Printf("read error,ERR:%s", err) - return - } - log.Printf("client control connection, key: %s", key) - if s.clientControlConns.Has(key) { - item, _ := s.clientControlConns.Get(key) - (*item.(*net.Conn)).Close() - } - s.clientControlConns.Set(key, &inConn) - log.Printf("set client %s control conn", key) - - // case CONN_SERVER_HEARBEAT: - // var serverID string - // err = utils.ReadPacketData(reader, &serverID) - // if err != nil { - // log.Printf("read error,ERR:%s", err) - // return - // } - // log.Printf("server heartbeat connection, id: %s", serverID) - // writeDie := make(chan bool) - // readDie := make(chan bool) - // go func() { - // for { - // inConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) - // _, err = inConn.Write([]byte{0x00}) - // inConn.SetWriteDeadline(time.Time{}) - // if err != nil { - // log.Printf("server heartbeat connection write err %s", err) - // break - // } - // time.Sleep(time.Second * 3) - // } - // close(writeDie) - // }() - // go func() { - // for { - // signal := make([]byte, 1) - // inConn.SetReadDeadline(time.Now().Add(time.Second * 6)) - // _, err := inConn.Read(signal) - // inConn.SetReadDeadline(time.Time{}) - // if err != nil { - // log.Printf("server heartbeat connection read err: %s", err) - // break - // } else { - // // log.Printf("heartbeat from server ,id:%s", serverID) - // } - // } - // close(readDie) - // }() - // select { - // case <-readDie: - // case <-writeDie: - // } - // utils.CloseConn(&inConn) - // s.cmServer.Remove(serverID) - // log.Printf("server heartbeat conn %s released", serverID) - // case CONN_CLIENT_HEARBEAT: - // var clientID string - // err = utils.ReadPacketData(reader, &clientID) - // if err != nil { - // log.Printf("read error,ERR:%s", err) - // return - // } - // log.Printf("client heartbeat connection, id: %s", clientID) - // writeDie := make(chan bool) - // readDie := make(chan bool) - // go func() { - // for { - // inConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) - // _, err = inConn.Write([]byte{0x00}) - // inConn.SetWriteDeadline(time.Time{}) - // if err != nil { - // log.Printf("client heartbeat connection write err %s", err) - // break - // } - // time.Sleep(time.Second * 3) - // } - // close(writeDie) - // }() - // go func() { - // for { - // signal := make([]byte, 1) - // inConn.SetReadDeadline(time.Now().Add(time.Second * 6)) - // _, err := inConn.Read(signal) - // inConn.SetReadDeadline(time.Time{}) - // if err != nil { - // log.Printf("client control connection read err: %s", err) - // break - // } else { - // // log.Printf("heartbeat from client ,id:%s", clientID) - // } - // } - // close(readDie) - // }() - // select { - // case <-readDie: - // case <-writeDie: - // } - // utils.CloseConn(&inConn) - // s.cmClient.Remove(clientID) - // if s.clientControlConns.Has(clientID) { - // item, _ := s.clientControlConns.Get(clientID) - // (*item.(*net.Conn)).Close() - // } - // s.clientControlConns.Remove(clientID) - // log.Printf("client heartbeat conn %s released", clientID) + log.Printf("server connection %s", conn.RemoteAddr()) + //s.callback(conn) } + s.callback(inConn) }) if err != nil { return } - log.Printf("proxy on tunnel bridge mode %s", (*sc.Listener).Addr()) + log.Printf("proxy on mux bridge mode %s", (*sc.Listener).Addr()) return } func (s *MuxBridge) Clean() { s.StopService() } +func (s *MuxBridge) callback(inConn net.Conn) { + reader := bufio.NewReader(inConn) + var err error + var connType uint8 + err = utils.ReadPacket(reader, &connType) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + switch connType { + case CONN_SERVER: + var key, ID, clientLocalAddr, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + packet := utils.BuildPacketData(ID, clientLocalAddr, serverID) + log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID) + + //addr := clientLocalAddr + "@" + ID + s.serverConns.Set(ID, MuxServerConn{ + Conn: &inConn, + }) + for { + item, ok := s.clientControlConns.Get(key) + if !ok { + log.Printf("client %s control conn not exists", key) + time.Sleep(time.Second * 3) + continue + } + (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) + _, err := (*item.(*net.Conn)).Write(packet) + (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) + if err != nil { + log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) + time.Sleep(time.Second * 3) + continue + } else { + break + } + } + case CONN_CLIENT: + var key, ID, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID) + + serverConnItem, ok := s.serverConns.Get(ID) + if !ok { + inConn.Close() + log.Printf("server conn %s exists", ID) + return + } + serverConn := serverConnItem.(MuxServerConn).Conn + utils.IoBind(*serverConn, inConn, func(err interface{}) { + s.serverConns.Remove(ID) + // s.cmClient.RemoveOne(key, ID) + // s.cmServer.RemoveOne(serverID, ID) + log.Printf("conn %s released", ID) + }) + // s.cmClient.Add(key, ID, &inConn) + log.Printf("conn %s created", ID) + + case CONN_CLIENT_CONTROL: + var key string + err = utils.ReadPacketData(reader, &key) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client control connection, key: %s", key) + if s.clientControlConns.Has(key) { + item, _ := s.clientControlConns.Get(key) + (*item.(*net.Conn)).Close() + } + s.clientControlConns.Set(key, &inConn) + log.Printf("set client %s control conn", key) + } +} diff --git a/services/mux_server.go b/services/mux_server.go index 3e08193..581ddb6 100644 --- a/services/mux_server.go +++ b/services/mux_server.go @@ -11,12 +11,16 @@ import ( "strconv" "strings" "time" + + "github.com/xtaci/smux" ) type MuxServer struct { - cfg MuxServerArgs - udpChn chan MuxUDPItem - sc utils.ServerChannel + cfg MuxServerArgs + udpChn chan MuxUDPItem + sc utils.ServerChannel + underLayerConn net.Conn + session *smux.Session } type MuxServerManager struct { @@ -100,86 +104,9 @@ func (s *MuxServerManager) CheckArgs() { s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) } func (s *MuxServerManager) InitService() { - // s.InitHeartbeatDeamon() + } -// func (s *MuxServerManager) InitHeartbeatDeamon() { -// log.Printf("heartbeat started") -// go func() { -// var heartbeatConn net.Conn -// var ID string -// for { -// //close all connection -// s.cm.Remove(ID) -// utils.CloseConn(&heartbeatConn) -// heartbeatConn, ID, err := s.GetOutConn(CONN_SERVER_HEARBEAT) -// if err != nil { -// log.Printf("heartbeat connection err: %s, retrying...", err) -// time.Sleep(time.Second * 3) -// utils.CloseConn(&heartbeatConn) -// continue -// } -// log.Printf("heartbeat connection created,id:%s", ID) -// writeDie := make(chan bool) -// readDie := make(chan bool) -// go func() { -// for { -// heartbeatConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) -// _, err = heartbeatConn.Write([]byte{0x00}) -// heartbeatConn.SetWriteDeadline(time.Time{}) -// if err != nil { -// log.Printf("heartbeat connection write err %s", err) -// break -// } -// time.Sleep(time.Second * 3) -// } -// close(writeDie) -// }() -// go func() { -// for { -// signal := make([]byte, 1) -// heartbeatConn.SetReadDeadline(time.Now().Add(time.Second * 6)) -// _, err := heartbeatConn.Read(signal) -// heartbeatConn.SetReadDeadline(time.Time{}) -// if err != nil { -// log.Printf("heartbeat connection read err: %s", err) -// break -// } else { -// // log.Printf("heartbeat from bridge") -// } -// } -// close(readDie) -// }() -// select { -// case <-readDie: -// case <-writeDie: -// } -// } -// }() -// } -func (s *MuxServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) { - outConn, err = s.GetConn() - if err != nil { - log.Printf("connection err: %s", err) - return - } - ID = s.serverID - _, err = outConn.Write(utils.BuildPacket(typ, s.serverID)) - if err != nil { - log.Printf("write connection data err: %s ,retrying...", err) - utils.CloseConn(&outConn) - return - } - return -} -func (s *MuxServerManager) GetConn() (conn net.Conn, err error) { - var _conn tls.Conn - _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) - if err == nil { - conn = net.Conn(&_conn) - } - return -} func NewMuxServer() Service { return &MuxServer{ cfg: MuxServerArgs{},