package services import ( "crypto/tls" "fmt" "io" "log" "net" "proxy/utils" "time" ) type TunnelClient struct { cfg TunnelClientArgs cm utils.ConnManager ctrlConn net.Conn } func NewTunnelClient() Service { return &TunnelClient{ cfg: TunnelClientArgs{}, cm: utils.NewConnManager(), } } func (s *TunnelClient) InitService() { s.InitHeartbeatDeamon() } func (s *TunnelClient) InitHeartbeatDeamon() { log.Printf("heartbeat started") go func() { var heartbeatConn net.Conn var ID = *s.cfg.Key for { //close all connection s.cm.RemoveAll() if s.ctrlConn != nil { s.ctrlConn.Close() } utils.CloseConn(&heartbeatConn) heartbeatConn, err := s.GetInConn(CONN_CLIENT_HEARBEAT, ID) if err != nil { log.Printf("heartbeat connection err: %s, retrying...", err) time.Sleep(time.Second * 3) utils.CloseConn(&heartbeatConn) continue } log.Printf("heartbeat connection created,id:%s", ID) writeDie := make(chan bool) readDie := make(chan bool) go func() { for { heartbeatConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) _, err = heartbeatConn.Write([]byte{0x00}) heartbeatConn.SetWriteDeadline(time.Time{}) if err != nil { log.Printf("heartbeat connection write err %s", err) break } time.Sleep(time.Second * 3) } close(writeDie) }() go func() { for { signal := make([]byte, 1) heartbeatConn.SetReadDeadline(time.Now().Add(time.Second * 6)) _, err := heartbeatConn.Read(signal) heartbeatConn.SetReadDeadline(time.Time{}) if err != nil { log.Printf("heartbeat connection read err: %s", err) break } else { //log.Printf("heartbeat from bridge") } } close(readDie) }() select { case <-readDie: case <-writeDie: } } }() } func (s *TunnelClient) CheckArgs() { if *s.cfg.Parent != "" { log.Printf("use tls parent %s", *s.cfg.Parent) } else { log.Fatalf("parent required") } if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { log.Fatalf("cert and key file required") } s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) } func (s *TunnelClient) StopService() { s.cm.RemoveAll() } func (s *TunnelClient) Start(args interface{}) (err error) { s.cfg = args.(TunnelClientArgs) s.CheckArgs() s.InitService() log.Printf("proxy on tunnel client mode") for { //close all conn s.cm.Remove(*s.cfg.Key) if s.ctrlConn != nil { s.ctrlConn.Close() } s.ctrlConn, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key) if err != nil { log.Printf("control connection err: %s, retrying...", err) time.Sleep(time.Second * 3) if s.ctrlConn != nil { s.ctrlConn.Close() } continue } for { var ID, clientLocalAddr, serverID string err = utils.ReadPacketData(s.ctrlConn, &ID, &clientLocalAddr, &serverID) if err != nil { if s.ctrlConn != nil { s.ctrlConn.Close() } log.Printf("read connection signal err: %s, retrying...", err) break } log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr) protocol := clientLocalAddr[:3] localAddr := clientLocalAddr[4:] if protocol == "udp" { go s.ServeUDP(localAddr, ID, serverID) } else { go s.ServeConn(localAddr, ID, serverID) } } } } func (s *TunnelClient) Clean() { s.StopService() } func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, err error) { outConn, err = s.GetConn() if err != nil { err = fmt.Errorf("connection err: %s", err) return } _, err = outConn.Write(utils.BuildPacket(typ, data...)) if err != nil { err = fmt.Errorf("write connection data err: %s ,retrying...", err) utils.CloseConn(&outConn) return } return } func (s *TunnelClient) GetConn() (conn net.Conn, err error) { var _conn tls.Conn _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) if err == nil { conn = net.Conn(&_conn) } return } func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) { var inConn net.Conn var err error // for { for { s.cm.RemoveOne(*s.cfg.Key, ID) inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) if err != nil { utils.CloseConn(&inConn) log.Printf("connection err: %s, retrying...", err) time.Sleep(time.Second * 3) continue } else { break } } s.cm.Add(*s.cfg.Key, ID, &inConn) log.Printf("conn %s created", ID) for { srcAddr, body, err := utils.ReadUDPPacket(inConn) if err == io.EOF || err == io.ErrUnexpectedEOF { log.Printf("connection %s released", ID) utils.CloseConn(&inConn) break } else if err != nil { log.Printf("udp packet revecived fail, err: %s", err) } else { //log.Printf("udp packet revecived:%s,%v", srcAddr, body) go s.processUDPPacket(&inConn, srcAddr, localAddr, body) } } // } } func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr string, body []byte) { dstAddr, err := net.ResolveUDPAddr("udp", localAddr) if err != nil { log.Printf("can't resolve address: %s", err) utils.CloseConn(inConn) return } clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) if err != nil { log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) return } conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) _, err = conn.Write(body) if err != nil { log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) return } //log.Printf("send udp packet to %s success", dstAddr.String()) buf := make([]byte, 1024) length, _, err := conn.ReadFromUDP(buf) if err != nil { log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) return } respBody := buf[0:length] //log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody) bs := utils.UDPPacket(srcAddr, respBody) _, err = (*inConn).Write(bs) if err != nil { log.Printf("send udp response fail ,ERR:%s", err) utils.CloseConn(inConn) return } //log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs) } func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) { var inConn, outConn net.Conn var err error for { inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) if err != nil { utils.CloseConn(&inConn) log.Printf("connection err: %s, retrying...", err) time.Sleep(time.Second * 3) continue } else { break } } i := 0 for { i++ outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout) if err == nil || i == 3 { break } else { if i == 3 { log.Printf("connect to %s err: %s, retrying...", localAddr, err) time.Sleep(2 * time.Second) continue } } } if err != nil { utils.CloseConn(&inConn) utils.CloseConn(&outConn) log.Printf("build connection error, err: %s", err) return } utils.IoBind(inConn, outConn, func(err error) { log.Printf("conn %s released", ID) utils.CloseConn(&inConn) utils.CloseConn(&outConn) s.cm.RemoveOne(*s.cfg.Key, ID) }) s.cm.Add(*s.cfg.Key, ID, &inConn) log.Printf("conn %s created", ID) }