diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index b3d4c7d..ba4d0a8 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -70,95 +70,7 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { p, _ := strconv.Atoi(port) sc := utils.NewServerChannel(host, p) - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, func(inConn net.Conn) { - //log.Printf("connection from %s ", inConn.RemoteAddr()) - - reader := bufio.NewReader(inConn) - var err error - var connType uint8 - err = utils.ReadPacket(reader, &connType) - if err != nil { - log.Printf("read error,ERR:%s", err) - return - } - switch connType { - case CONN_SERVER: - var key, ID, clientLocalAddr, serverID string - err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID) - if err != nil { - log.Printf("read error,ERR:%s", err) - return - } - packet := utils.BuildPacketData(ID, clientLocalAddr, serverID) - log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID) - - //addr := clientLocalAddr + "@" + ID - s.serverConns.Set(ID, ServerConn{ - Conn: &inConn, - }) - for { - if s.isStop { - return - } - item, ok := s.clientControlConns.Get(key) - if !ok { - log.Printf("client %s control conn not exists", key) - time.Sleep(time.Second * 3) - continue - } - (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) - _, err := (*item.(*net.Conn)).Write(packet) - (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) - if err != nil { - log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) - time.Sleep(time.Second * 3) - continue - } else { - // s.cmServer.Add(serverID, ID, &inConn) - break - } - } - case CONN_CLIENT: - var key, ID, serverID string - err = utils.ReadPacketData(reader, &key, &ID, &serverID) - if err != nil { - log.Printf("read error,ERR:%s", err) - return - } - log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID) - - serverConnItem, ok := s.serverConns.Get(ID) - if !ok { - inConn.Close() - log.Printf("server conn %s exists", ID) - return - } - serverConn := serverConnItem.(ServerConn).Conn - utils.IoBind(*serverConn, inConn, func(err interface{}) { - s.serverConns.Remove(ID) - // s.cmClient.RemoveOne(key, ID) - // s.cmServer.RemoveOne(serverID, ID) - log.Printf("conn %s released", ID) - }) - // s.cmClient.Add(key, ID, &inConn) - log.Printf("conn %s created", ID) - - case CONN_CLIENT_CONTROL: - var key string - err = utils.ReadPacketData(reader, &key) - if err != nil { - log.Printf("read error,ERR:%s", err) - return - } - log.Printf("client control connection, key: %s", key) - if s.clientControlConns.Has(key) { - item, _ := s.clientControlConns.Get(key) - (*item.(*net.Conn)).Close() - } - s.clientControlConns.Set(key, &inConn) - log.Printf("set client %s control conn", key) - } - }) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback) if err != nil { return } @@ -168,3 +80,92 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { func (s *TunnelBridge) Clean() { s.StopService() } +func (s *TunnelBridge) callback(inConn net.Conn) { + //log.Printf("connection from %s ", inConn.RemoteAddr()) + + reader := bufio.NewReader(inConn) + var err error + var connType uint8 + err = utils.ReadPacket(reader, &connType) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + switch connType { + case CONN_SERVER: + var key, ID, clientLocalAddr, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + packet := utils.BuildPacketData(ID, clientLocalAddr, serverID) + log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID) + + //addr := clientLocalAddr + "@" + ID + s.serverConns.Set(ID, ServerConn{ + Conn: &inConn, + }) + for { + if s.isStop { + return + } + item, ok := s.clientControlConns.Get(key) + if !ok { + log.Printf("client %s control conn not exists", key) + time.Sleep(time.Second * 3) + continue + } + (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) + _, err := (*item.(*net.Conn)).Write(packet) + (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) + if err != nil { + log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) + time.Sleep(time.Second * 3) + continue + } else { + // s.cmServer.Add(serverID, ID, &inConn) + break + } + } + case CONN_CLIENT: + var key, ID, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID) + + serverConnItem, ok := s.serverConns.Get(ID) + if !ok { + inConn.Close() + log.Printf("server conn %s exists", ID) + return + } + serverConn := serverConnItem.(ServerConn).Conn + utils.IoBind(*serverConn, inConn, func(err interface{}) { + s.serverConns.Remove(ID) + // s.cmClient.RemoveOne(key, ID) + // s.cmServer.RemoveOne(serverID, ID) + log.Printf("conn %s released", ID) + }) + // s.cmClient.Add(key, ID, &inConn) + log.Printf("conn %s created", ID) + + case CONN_CLIENT_CONTROL: + var key string + err = utils.ReadPacketData(reader, &key) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client control connection, key: %s", key) + if s.clientControlConns.Has(key) { + item, _ := s.clientControlConns.Get(key) + (*item.(*net.Conn)).Close() + } + s.clientControlConns.Set(key, &inConn) + log.Printf("set client %s control conn", key) + } +}