From e2cd0b8e4fdbeb4af908f915b051c3fb0e313030 Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Thu, 30 Nov 2017 16:50:47 +0800 Subject: [PATCH] Signed-off-by: arraykeys@gmail.com --- config.go | 34 +++- services/args.go | 32 ++++ services/mux_bridge.go | 243 +++++++++++++++++++++++++++ services/mux_client.go | 282 ++++++++++++++++++++++++++++++++ services/mux_server.go | 361 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 949 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index 396019f..d9694b0 100755 --- a/config.go +++ b/config.go @@ -34,6 +34,9 @@ func initConfig() (err error) { tunnelServerArgs := services.TunnelServerArgs{} tunnelClientArgs := services.TunnelClientArgs{} tunnelBridgeArgs := services.TunnelBridgeArgs{} + muxServerArgs := services.MuxServerArgs{} + muxClientArgs := services.MuxClientArgs{} + muxBridgeArgs := services.MuxBridgeArgs{} udpArgs := services.UDPArgs{} socksArgs := services.SocksArgs{} //build srvice args @@ -99,6 +102,31 @@ func initConfig() (err error) { udpArgs.CheckParentInterval = udp.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int() udpArgs.Local = udp.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String() + //########mux-server######### + muxServer := app.Command("server", "proxy on mux server mode") + muxServerArgs.Parent = muxServer.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() + muxServerArgs.CertFile = muxServer.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() + muxServerArgs.KeyFile = muxServer.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() + muxServerArgs.Timeout = muxServer.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() + muxServerArgs.IsUDP = muxServer.Flag("udp", "proxy on udp mux server mode").Default("false").Bool() + muxServerArgs.Key = muxServer.Flag("k", "client key").Default("default").String() + muxServerArgs.Route = muxServer.Flag("route", "local route to client's network, such as :PROTOCOL://LOCAL_IP:LOCAL_PORT@[CLIENT_KEY]CLIENT_LOCAL_HOST:CLIENT_LOCAL_PORT").Short('r').Default("").Strings() + + //########mux-client######### + muxClient := app.Command("client", "proxy on mux client mode") + muxClientArgs.Parent = muxClient.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() + muxClientArgs.CertFile = muxClient.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() + muxClientArgs.KeyFile = muxClient.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() + muxClientArgs.Timeout = muxClient.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() + muxClientArgs.Key = muxClient.Flag("k", "key same with server").Default("default").String() + + //########mux-bridge######### + muxBridge := app.Command("bridge", "proxy on mux bridge mode") + muxBridgeArgs.CertFile = muxBridge.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() + muxBridgeArgs.KeyFile = muxBridge.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() + muxBridgeArgs.Timeout = muxBridge.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() + muxBridgeArgs.Local = muxBridge.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String() + //########tunnel-server######### tunnelServer := app.Command("tserver", "proxy on tunnel server mode") tunnelServerArgs.Parent = tunnelServer.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() @@ -108,7 +136,6 @@ func initConfig() (err error) { tunnelServerArgs.IsUDP = tunnelServer.Flag("udp", "proxy on udp tunnel server mode").Default("false").Bool() tunnelServerArgs.Key = tunnelServer.Flag("k", "client key").Default("default").String() tunnelServerArgs.Route = tunnelServer.Flag("route", "local route to client's network, such as :PROTOCOL://LOCAL_IP:LOCAL_PORT@[CLIENT_KEY]CLIENT_LOCAL_HOST:CLIENT_LOCAL_PORT").Short('r').Default("").Strings() - tunnelServerArgs.Mux = tunnelServer.Flag("mux", "use multiplexing mode").Default("false").Bool() //########tunnel-client######### tunnelClient := app.Command("tclient", "proxy on tunnel client mode") @@ -117,7 +144,6 @@ func initConfig() (err error) { tunnelClientArgs.KeyFile = tunnelClient.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() tunnelClientArgs.Timeout = tunnelClient.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() tunnelClientArgs.Key = tunnelClient.Flag("k", "key same with server").Default("default").String() - tunnelClientArgs.Mux = tunnelClient.Flag("mux", "use multiplexing mode").Default("false").Bool() //########tunnel-bridge######### tunnelBridge := app.Command("tbridge", "proxy on tunnel bridge mode") @@ -125,7 +151,6 @@ func initConfig() (err error) { tunnelBridgeArgs.KeyFile = tunnelBridge.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() tunnelBridgeArgs.Timeout = tunnelBridge.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() tunnelBridgeArgs.Local = tunnelBridge.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String() - tunnelBridgeArgs.Mux = tunnelBridge.Flag("mux", "use multiplexing mode").Default("false").Bool() //########ssh######### socks := app.Command("socks", "proxy on ssh mode") @@ -250,6 +275,9 @@ func initConfig() (err error) { services.Regist("tserver", services.NewTunnelServerManager(), tunnelServerArgs) services.Regist("tclient", services.NewTunnelClient(), tunnelClientArgs) services.Regist("tbridge", services.NewTunnelBridge(), tunnelBridgeArgs) + services.Regist("server", services.NewMuxServerManager(), muxServerArgs) + services.Regist("client", services.NewMuxClient(), muxClientArgs) + services.Regist("bridge", services.NewMuxBridge(), muxBridgeArgs) services.Regist("socks", services.NewSocks(), socksArgs) service, err = services.Run(serviceName) if err != nil { diff --git a/services/args.go b/services/args.go index a1e732c..d5cb9d7 100644 --- a/services/args.go +++ b/services/args.go @@ -20,6 +20,38 @@ const ( CONN_CLIENT_MUX = uint8(7) ) +type MuxServerArgs struct { + Parent *string + CertFile *string + KeyFile *string + CertBytes []byte + KeyBytes []byte + Local *string + IsUDP *bool + Key *string + Remote *string + Timeout *int + Route *[]string + Mgr *MuxServerManager +} +type MuxClientArgs struct { + Parent *string + CertFile *string + KeyFile *string + CertBytes []byte + KeyBytes []byte + Key *string + Timeout *int +} +type MuxBridgeArgs struct { + Parent *string + CertFile *string + KeyFile *string + CertBytes []byte + KeyBytes []byte + Local *string + Timeout *int +} type TunnelServerArgs struct { Parent *string CertFile *string diff --git a/services/mux_bridge.go b/services/mux_bridge.go index 5e568ea..11ab8dc 100644 --- a/services/mux_bridge.go +++ b/services/mux_bridge.go @@ -1 +1,244 @@ package services + +import ( + "bufio" + "log" + "net" + "proxy/utils" + "strconv" + "time" +) + +type MuxServerConn struct { + //ClientLocalAddr string //tcp:2.2.22:333@ID + Conn *net.Conn +} +type MuxBridge struct { + cfg MuxBridgeArgs + serverConns utils.ConcurrentMap + clientControlConns utils.ConcurrentMap + // cmServer utils.ConnManager + // cmClient utils.ConnManager +} + +func NewMuxBridge() Service { + return &MuxBridge{ + cfg: MuxBridgeArgs{}, + serverConns: utils.NewConcurrentMap(), + clientControlConns: utils.NewConcurrentMap(), + // cmServer: utils.NewConnManager(), + // cmClient: utils.NewConnManager(), + } +} + +func (s *MuxBridge) InitService() { + +} +func (s *MuxBridge) CheckArgs() { + if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { + log.Fatalf("cert and key file required") + } + s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) +} +func (s *MuxBridge) StopService() { + +} +func (s *MuxBridge) Start(args interface{}) (err error) { + s.cfg = args.(MuxBridgeArgs) + s.CheckArgs() + s.InitService() + host, port, _ := net.SplitHostPort(*s.cfg.Local) + p, _ := strconv.Atoi(port) + sc := utils.NewServerChannel(host, p) + + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) { + //log.Printf("connection from %s ", inConn.RemoteAddr()) + + reader := bufio.NewReader(inConn) + var err error + var connType uint8 + err = utils.ReadPacket(reader, &connType) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + switch connType { + case CONN_SERVER: + var key, ID, clientLocalAddr, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + packet := utils.BuildPacketData(ID, clientLocalAddr, serverID) + log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID) + + //addr := clientLocalAddr + "@" + ID + s.serverConns.Set(ID, MuxServerConn{ + Conn: &inConn, + }) + for { + item, ok := s.clientControlConns.Get(key) + if !ok { + log.Printf("client %s control conn not exists", key) + time.Sleep(time.Second * 3) + continue + } + (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) + _, err := (*item.(*net.Conn)).Write(packet) + (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) + if err != nil { + log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) + time.Sleep(time.Second * 3) + continue + } else { + // s.cmServer.Add(serverID, ID, &inConn) + break + } + } + case CONN_CLIENT: + var key, ID, serverID string + err = utils.ReadPacketData(reader, &key, &ID, &serverID) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID) + + serverConnItem, ok := s.serverConns.Get(ID) + if !ok { + inConn.Close() + log.Printf("server conn %s exists", ID) + return + } + serverConn := serverConnItem.(MuxServerConn).Conn + utils.IoBind(*serverConn, inConn, func(err interface{}) { + s.serverConns.Remove(ID) + // s.cmClient.RemoveOne(key, ID) + // s.cmServer.RemoveOne(serverID, ID) + log.Printf("conn %s released", ID) + }) + // s.cmClient.Add(key, ID, &inConn) + log.Printf("conn %s created", ID) + + case CONN_CLIENT_CONTROL: + var key string + err = utils.ReadPacketData(reader, &key) + if err != nil { + log.Printf("read error,ERR:%s", err) + return + } + log.Printf("client control connection, key: %s", key) + if s.clientControlConns.Has(key) { + item, _ := s.clientControlConns.Get(key) + (*item.(*net.Conn)).Close() + } + s.clientControlConns.Set(key, &inConn) + log.Printf("set client %s control conn", key) + + // case CONN_SERVER_HEARBEAT: + // var serverID string + // err = utils.ReadPacketData(reader, &serverID) + // if err != nil { + // log.Printf("read error,ERR:%s", err) + // return + // } + // log.Printf("server heartbeat connection, id: %s", serverID) + // writeDie := make(chan bool) + // readDie := make(chan bool) + // go func() { + // for { + // inConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) + // _, err = inConn.Write([]byte{0x00}) + // inConn.SetWriteDeadline(time.Time{}) + // if err != nil { + // log.Printf("server heartbeat connection write err %s", err) + // break + // } + // time.Sleep(time.Second * 3) + // } + // close(writeDie) + // }() + // go func() { + // for { + // signal := make([]byte, 1) + // inConn.SetReadDeadline(time.Now().Add(time.Second * 6)) + // _, err := inConn.Read(signal) + // inConn.SetReadDeadline(time.Time{}) + // if err != nil { + // log.Printf("server heartbeat connection read err: %s", err) + // break + // } else { + // // log.Printf("heartbeat from server ,id:%s", serverID) + // } + // } + // close(readDie) + // }() + // select { + // case <-readDie: + // case <-writeDie: + // } + // utils.CloseConn(&inConn) + // s.cmServer.Remove(serverID) + // log.Printf("server heartbeat conn %s released", serverID) + // case CONN_CLIENT_HEARBEAT: + // var clientID string + // err = utils.ReadPacketData(reader, &clientID) + // if err != nil { + // log.Printf("read error,ERR:%s", err) + // return + // } + // log.Printf("client heartbeat connection, id: %s", clientID) + // writeDie := make(chan bool) + // readDie := make(chan bool) + // go func() { + // for { + // inConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) + // _, err = inConn.Write([]byte{0x00}) + // inConn.SetWriteDeadline(time.Time{}) + // if err != nil { + // log.Printf("client heartbeat connection write err %s", err) + // break + // } + // time.Sleep(time.Second * 3) + // } + // close(writeDie) + // }() + // go func() { + // for { + // signal := make([]byte, 1) + // inConn.SetReadDeadline(time.Now().Add(time.Second * 6)) + // _, err := inConn.Read(signal) + // inConn.SetReadDeadline(time.Time{}) + // if err != nil { + // log.Printf("client control connection read err: %s", err) + // break + // } else { + // // log.Printf("heartbeat from client ,id:%s", clientID) + // } + // } + // close(readDie) + // }() + // select { + // case <-readDie: + // case <-writeDie: + // } + // utils.CloseConn(&inConn) + // s.cmClient.Remove(clientID) + // if s.clientControlConns.Has(clientID) { + // item, _ := s.clientControlConns.Get(clientID) + // (*item.(*net.Conn)).Close() + // } + // s.clientControlConns.Remove(clientID) + // log.Printf("client heartbeat conn %s released", clientID) + } + }) + if err != nil { + return + } + log.Printf("proxy on tunnel bridge mode %s", (*sc.Listener).Addr()) + return +} +func (s *MuxBridge) Clean() { + s.StopService() +} diff --git a/services/mux_client.go b/services/mux_client.go index 5e568ea..08e05de 100644 --- a/services/mux_client.go +++ b/services/mux_client.go @@ -1 +1,283 @@ package services + +import ( + "crypto/tls" + "fmt" + "io" + "log" + "net" + "proxy/utils" + "time" +) + +type MuxClient struct { + cfg MuxClientArgs + // cm utils.ConnManager + ctrlConn net.Conn +} + +func NewMuxClient() Service { + return &MuxClient{ + cfg: MuxClientArgs{}, + // cm: utils.NewConnManager(), + } +} + +func (s *MuxClient) InitService() { + // s.InitHeartbeatDeamon() +} + +// func (s *MuxClient) InitHeartbeatDeamon() { +// log.Printf("heartbeat started") +// go func() { +// var heartbeatConn net.Conn +// var ID = *s.cfg.Key +// for { + +// //close all connection +// s.cm.RemoveAll() +// if s.ctrlConn != nil { +// s.ctrlConn.Close() +// } +// utils.CloseConn(&heartbeatConn) +// heartbeatConn, err := s.GetInConn(CONN_CLIENT_HEARBEAT, ID) +// if err != nil { +// log.Printf("heartbeat connection err: %s, retrying...", err) +// time.Sleep(time.Second * 3) +// utils.CloseConn(&heartbeatConn) +// continue +// } +// log.Printf("heartbeat connection created,id:%s", ID) +// writeDie := make(chan bool) +// readDie := make(chan bool) +// go func() { +// for { +// heartbeatConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) +// _, err = heartbeatConn.Write([]byte{0x00}) +// heartbeatConn.SetWriteDeadline(time.Time{}) +// if err != nil { +// log.Printf("heartbeat connection write err %s", err) +// break +// } +// time.Sleep(time.Second * 3) +// } +// close(writeDie) +// }() +// go func() { +// for { +// signal := make([]byte, 1) +// heartbeatConn.SetReadDeadline(time.Now().Add(time.Second * 6)) +// _, err := heartbeatConn.Read(signal) +// heartbeatConn.SetReadDeadline(time.Time{}) +// if err != nil { +// log.Printf("heartbeat connection read err: %s", err) +// break +// } else { +// //log.Printf("heartbeat from bridge") +// } +// } +// close(readDie) +// }() +// select { +// case <-readDie: +// case <-writeDie: +// } +// } +// }() +// } +func (s *MuxClient) CheckArgs() { + if *s.cfg.Parent != "" { + log.Printf("use tls parent %s", *s.cfg.Parent) + } else { + log.Fatalf("parent required") + } + if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { + log.Fatalf("cert and key file required") + } + s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) +} +func (s *MuxClient) StopService() { + // s.cm.RemoveAll() +} +func (s *MuxClient) Start(args interface{}) (err error) { + s.cfg = args.(MuxClientArgs) + s.CheckArgs() + s.InitService() + log.Printf("proxy on tunnel client mode") + + for { + //close all conn + // s.cm.Remove(*s.cfg.Key) + if s.ctrlConn != nil { + s.ctrlConn.Close() + } + + s.ctrlConn, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key) + if err != nil { + log.Printf("control connection err: %s, retrying...", err) + time.Sleep(time.Second * 3) + if s.ctrlConn != nil { + s.ctrlConn.Close() + } + continue + } + for { + var ID, clientLocalAddr, serverID string + err = utils.ReadPacketData(s.ctrlConn, &ID, &clientLocalAddr, &serverID) + if err != nil { + if s.ctrlConn != nil { + s.ctrlConn.Close() + } + log.Printf("read connection signal err: %s, retrying...", err) + break + } + log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr) + protocol := clientLocalAddr[:3] + localAddr := clientLocalAddr[4:] + if protocol == "udp" { + go s.ServeUDP(localAddr, ID, serverID) + } else { + go s.ServeConn(localAddr, ID, serverID) + } + } + } +} +func (s *MuxClient) Clean() { + s.StopService() +} +func (s *MuxClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, err error) { + outConn, err = s.GetConn() + if err != nil { + err = fmt.Errorf("connection err: %s", err) + return + } + _, err = outConn.Write(utils.BuildPacket(typ, data...)) + if err != nil { + err = fmt.Errorf("write connection data err: %s ,retrying...", err) + utils.CloseConn(&outConn) + return + } + return +} +func (s *MuxClient) GetConn() (conn net.Conn, err error) { + var _conn tls.Conn + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + if err == nil { + conn = net.Conn(&_conn) + } + return +} +func (s *MuxClient) ServeUDP(localAddr, ID, serverID string) { + var inConn net.Conn + var err error + // for { + for { + // s.cm.RemoveOne(*s.cfg.Key, ID) + inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) + if err != nil { + utils.CloseConn(&inConn) + log.Printf("connection err: %s, retrying...", err) + time.Sleep(time.Second * 3) + continue + } else { + break + } + } + // s.cm.Add(*s.cfg.Key, ID, &inConn) + log.Printf("conn %s created", ID) + + for { + srcAddr, body, err := utils.ReadUDPPacket(inConn) + if err == io.EOF || err == io.ErrUnexpectedEOF { + log.Printf("connection %s released", ID) + utils.CloseConn(&inConn) + break + } else if err != nil { + log.Printf("udp packet revecived fail, err: %s", err) + } else { + //log.Printf("udp packet revecived:%s,%v", srcAddr, body) + go s.processUDPPacket(&inConn, srcAddr, localAddr, body) + } + + } + // } +} +func (s *MuxClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr string, body []byte) { + dstAddr, err := net.ResolveUDPAddr("udp", localAddr) + if err != nil { + log.Printf("can't resolve address: %s", err) + utils.CloseConn(inConn) + return + } + clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) + if err != nil { + log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) + return + } + conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) + _, err = conn.Write(body) + if err != nil { + log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) + return + } + //log.Printf("send udp packet to %s success", dstAddr.String()) + buf := make([]byte, 1024) + length, _, err := conn.ReadFromUDP(buf) + if err != nil { + log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) + return + } + respBody := buf[0:length] + //log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody) + bs := utils.UDPPacket(srcAddr, respBody) + _, err = (*inConn).Write(bs) + if err != nil { + log.Printf("send udp response fail ,ERR:%s", err) + utils.CloseConn(inConn) + return + } + //log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs) +} +func (s *MuxClient) ServeConn(localAddr, ID, serverID string) { + var inConn, outConn net.Conn + var err error + for { + inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) + if err != nil { + utils.CloseConn(&inConn) + log.Printf("connection err: %s, retrying...", err) + time.Sleep(time.Second * 3) + continue + } else { + break + } + } + + i := 0 + for { + i++ + outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout) + if err == nil || i == 3 { + break + } else { + if i == 3 { + log.Printf("connect to %s err: %s, retrying...", localAddr, err) + time.Sleep(2 * time.Second) + continue + } + } + } + if err != nil { + utils.CloseConn(&inConn) + utils.CloseConn(&outConn) + log.Printf("build connection error, err: %s", err) + return + } + utils.IoBind(inConn, outConn, func(err interface{}) { + log.Printf("conn %s released", ID) + // s.cm.RemoveOne(*s.cfg.Key, ID) + }) + // s.cm.Add(*s.cfg.Key, ID, &inConn) + log.Printf("conn %s created", ID) +} diff --git a/services/mux_server.go b/services/mux_server.go index 5e568ea..3e08193 100644 --- a/services/mux_server.go +++ b/services/mux_server.go @@ -1 +1,362 @@ package services + +import ( + "crypto/tls" + "fmt" + "io" + "log" + "net" + "proxy/utils" + "runtime/debug" + "strconv" + "strings" + "time" +) + +type MuxServer struct { + cfg MuxServerArgs + udpChn chan MuxUDPItem + sc utils.ServerChannel +} + +type MuxServerManager struct { + cfg MuxServerArgs + udpChn chan MuxUDPItem + sc utils.ServerChannel + serverID string + // cm utils.ConnManager +} + +func NewMuxServerManager() Service { + return &MuxServerManager{ + cfg: MuxServerArgs{}, + udpChn: make(chan MuxUDPItem, 50000), + serverID: utils.Uniqueid(), + // cm: utils.NewConnManager(), + } +} +func (s *MuxServerManager) Start(args interface{}) (err error) { + s.cfg = args.(MuxServerArgs) + s.CheckArgs() + if *s.cfg.Parent != "" { + log.Printf("use tls parent %s", *s.cfg.Parent) + } else { + log.Fatalf("parent required") + } + + s.InitService() + + log.Printf("server id: %s", s.serverID) + //log.Printf("route:%v", *s.cfg.Route) + for _, _info := range *s.cfg.Route { + IsUDP := *s.cfg.IsUDP + if strings.HasPrefix(_info, "udp://") { + IsUDP = true + } + info := strings.TrimPrefix(_info, "udp://") + info = strings.TrimPrefix(info, "tcp://") + _routeInfo := strings.Split(info, "@") + server := NewMuxServer() + local := _routeInfo[0] + remote := _routeInfo[1] + KEY := *s.cfg.Key + if strings.HasPrefix(remote, "[") { + KEY = remote[1:strings.LastIndex(remote, "]")] + remote = remote[strings.LastIndex(remote, "]")+1:] + } + if strings.HasPrefix(remote, ":") { + remote = fmt.Sprintf("127.0.0.1%s", remote) + } + err = server.Start(MuxServerArgs{ + CertBytes: s.cfg.CertBytes, + KeyBytes: s.cfg.KeyBytes, + Parent: s.cfg.Parent, + CertFile: s.cfg.CertFile, + KeyFile: s.cfg.KeyFile, + Local: &local, + IsUDP: &IsUDP, + Remote: &remote, + Key: &KEY, + Timeout: s.cfg.Timeout, + Mgr: s, + }) + + if err != nil { + return + } + } + return +} +func (s *MuxServerManager) Clean() { + s.StopService() +} +func (s *MuxServerManager) StopService() { + // s.cm.RemoveAll() +} +func (s *MuxServerManager) CheckArgs() { + if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { + log.Fatalf("cert and key file required") + } + s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) +} +func (s *MuxServerManager) InitService() { + // s.InitHeartbeatDeamon() +} + +// func (s *MuxServerManager) InitHeartbeatDeamon() { +// log.Printf("heartbeat started") +// go func() { +// var heartbeatConn net.Conn +// var ID string +// for { +// //close all connection +// s.cm.Remove(ID) +// utils.CloseConn(&heartbeatConn) +// heartbeatConn, ID, err := s.GetOutConn(CONN_SERVER_HEARBEAT) +// if err != nil { +// log.Printf("heartbeat connection err: %s, retrying...", err) +// time.Sleep(time.Second * 3) +// utils.CloseConn(&heartbeatConn) +// continue +// } +// log.Printf("heartbeat connection created,id:%s", ID) +// writeDie := make(chan bool) +// readDie := make(chan bool) +// go func() { +// for { +// heartbeatConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) +// _, err = heartbeatConn.Write([]byte{0x00}) +// heartbeatConn.SetWriteDeadline(time.Time{}) +// if err != nil { +// log.Printf("heartbeat connection write err %s", err) +// break +// } +// time.Sleep(time.Second * 3) +// } +// close(writeDie) +// }() +// go func() { +// for { +// signal := make([]byte, 1) +// heartbeatConn.SetReadDeadline(time.Now().Add(time.Second * 6)) +// _, err := heartbeatConn.Read(signal) +// heartbeatConn.SetReadDeadline(time.Time{}) +// if err != nil { +// log.Printf("heartbeat connection read err: %s", err) +// break +// } else { +// // log.Printf("heartbeat from bridge") +// } +// } +// close(readDie) +// }() +// select { +// case <-readDie: +// case <-writeDie: +// } +// } +// }() +// } +func (s *MuxServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) { + outConn, err = s.GetConn() + if err != nil { + log.Printf("connection err: %s", err) + return + } + ID = s.serverID + _, err = outConn.Write(utils.BuildPacket(typ, s.serverID)) + if err != nil { + log.Printf("write connection data err: %s ,retrying...", err) + utils.CloseConn(&outConn) + return + } + return +} +func (s *MuxServerManager) GetConn() (conn net.Conn, err error) { + var _conn tls.Conn + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + if err == nil { + conn = net.Conn(&_conn) + } + return +} +func NewMuxServer() Service { + return &MuxServer{ + cfg: MuxServerArgs{}, + udpChn: make(chan MuxUDPItem, 50000), + } +} + +type MuxUDPItem struct { + packet *[]byte + localAddr *net.UDPAddr + srcAddr *net.UDPAddr +} + +func (s *MuxServer) InitService() { + s.UDPConnDeamon() +} +func (s *MuxServer) CheckArgs() { + if *s.cfg.Remote == "" { + log.Fatalf("remote required") + } +} + +func (s *MuxServer) Start(args interface{}) (err error) { + s.cfg = args.(MuxServerArgs) + s.CheckArgs() + s.InitService() + host, port, _ := net.SplitHostPort(*s.cfg.Local) + p, _ := strconv.Atoi(port) + s.sc = utils.NewServerChannel(host, p) + if *s.cfg.IsUDP { + err = s.sc.ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) { + s.udpChn <- MuxUDPItem{ + packet: &packet, + localAddr: localAddr, + srcAddr: srcAddr, + } + }) + if err != nil { + return + } + log.Printf("proxy on udp tunnel server mode %s", (*s.sc.UDPListener).LocalAddr()) + } else { + err = s.sc.ListenTCP(func(inConn net.Conn) { + defer func() { + if err := recover(); err != nil { + log.Printf("tserver conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) + } + }() + var outConn net.Conn + var ID string + for { + outConn, ID, err = s.GetOutConn(CONN_SERVER) + if err != nil { + utils.CloseConn(&outConn) + log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) + time.Sleep(time.Second * 3) + continue + } else { + break + } + } + utils.IoBind(inConn, outConn, func(err interface{}) { + // s.cfg.Mgr.cm.RemoveOne(s.cfg.Mgr.serverID, ID) + log.Printf("%s conn %s released", *s.cfg.Key, ID) + }) + //add conn + // s.cfg.Mgr.cm.Add(s.cfg.Mgr.serverID, ID, &inConn) + log.Printf("%s conn %s created", *s.cfg.Key, ID) + }) + if err != nil { + return + } + log.Printf("proxy on tunnel server mode %s", (*s.sc.Listener).Addr()) + } + return +} +func (s *MuxServer) Clean() { + +} +func (s *MuxServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) { + outConn, err = s.GetConn() + if err != nil { + log.Printf("connection err: %s", err) + return + } + remoteAddr := "tcp:" + *s.cfg.Remote + if *s.cfg.IsUDP { + remoteAddr = "udp:" + *s.cfg.Remote + } + ID = utils.Uniqueid() + _, err = outConn.Write(utils.BuildPacket(typ, *s.cfg.Key, ID, remoteAddr, s.cfg.Mgr.serverID)) + if err != nil { + log.Printf("write connection data err: %s ,retrying...", err) + utils.CloseConn(&outConn) + return + } + return +} +func (s *MuxServer) GetConn() (conn net.Conn, err error) { + var _conn tls.Conn + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + if err == nil { + conn = net.Conn(&_conn) + } + return +} +func (s *MuxServer) UDPConnDeamon() { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("udp conn deamon crashed with err : %s \nstack: %s", err, string(debug.Stack())) + } + }() + var outConn net.Conn + // var hb utils.HeartbeatReadWriter + var ID string + // var cmdChn = make(chan bool, 1000) + var err error + for { + item := <-s.udpChn + RETRY: + if outConn == nil { + for { + outConn, ID, err = s.GetOutConn(CONN_SERVER) + if err != nil { + // cmdChn <- true + outConn = nil + utils.CloseConn(&outConn) + log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) + time.Sleep(time.Second * 3) + continue + } else { + go func(outConn net.Conn, ID string) { + go func() { + // <-cmdChn + // outConn.Close() + }() + for { + srcAddrFromConn, body, err := utils.ReadUDPPacket(outConn) + if err == io.EOF || err == io.ErrUnexpectedEOF { + log.Printf("UDP deamon connection %s exited", ID) + break + } + if err != nil { + log.Printf("parse revecived udp packet fail, err: %s ,%v", err, body) + continue + } + //log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn) + _srcAddr := strings.Split(srcAddrFromConn, ":") + if len(_srcAddr) != 2 { + log.Printf("parse revecived udp packet fail, addr error : %s", srcAddrFromConn) + continue + } + port, _ := strconv.Atoi(_srcAddr[1]) + dstAddr := &net.UDPAddr{IP: net.ParseIP(_srcAddr[0]), Port: port} + _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr) + if err != nil { + log.Printf("udp response to local %s fail,ERR:%s", srcAddrFromConn, err) + continue + } + //log.Printf("udp response to local %s success , %v", srcAddrFromConn, body) + } + }(outConn, ID) + break + } + } + } + outConn.SetWriteDeadline(time.Now().Add(time.Second)) + _, err = outConn.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet)) + outConn.SetWriteDeadline(time.Time{}) + if err != nil { + utils.CloseConn(&outConn) + outConn = nil + log.Printf("write udp packet to %s fail ,flush err:%s ,retrying...", *s.cfg.Parent, err) + goto RETRY + } + //log.Printf("write packet %v", *item.packet) + } + }() +}