diff --git a/CHANGELOG b/CHANGELOG index 1251942..b65e724 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,8 +1,11 @@ proxy更新日志: v3.0 -1.增加了代理死循环检查,增强了安全性。 -2.重构了全部代码,为下一步的功能拓展做准备。 -3.此次更新不兼容2.x版本。 +1.此次更新不兼容2.x版本,重构了全部代码,架构更合理,利于功能模块的增加与维护。 +2.增加了代理死循环检查,增强了安全性。 +3.增加了反向代理模式(即:内网穿透),支持TCP和UDP两种协议,可以把任何局域网的机器A所在网络的任何端口 + 暴露到任何局域网的机器B的本地端口或暴露到任何公网VPS上。 +4.正向代理增加了UDP模式支持。 + v2.2 1.增加了强制使用上级代理参数always.可以使所有流量都走上级代理。 diff --git a/config.go b/config.go index 2402a94..3932660 100755 --- a/config.go +++ b/config.go @@ -21,7 +21,9 @@ func initConfig() (err error) { //define args tcpArgs := services.TCPArgs{} httpArgs := services.HTTPArgs{} - tunnelArgs := services.TunnelArgs{} + tunnelServerArgs := services.TunnelServerArgs{} + tunnelClientArgs := services.TunnelClientArgs{} + tunnelBridgeArgs := services.TunnelBridgeArgs{} udpArgs := services.UDPArgs{} //build srvice args @@ -31,8 +33,6 @@ func initConfig() (err error) { args.Local = app.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String() certTLS := app.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() keyTLS := app.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() - args.PoolSize = app.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Short('L').Default("50").Int() - args.CheckParentInterval = app.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int() //########http######### http := app.Command("http", "proxy on http mode") @@ -46,29 +46,53 @@ func initConfig() (err error) { httpArgs.Direct = http.Flag("direct", "direct domain file , one domain each line").Default("direct").Short('d').String() httpArgs.AuthFile = http.Flag("auth-file", "http basic auth file,\"username:password\" each line in file").Short('F').String() httpArgs.Auth = http.Flag("auth", "http basic auth username and password, mutiple user repeat -a ,such as: -a user1:pass1 -a user2:pass2").Short('a').Strings() + httpArgs.PoolSize = http.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Short('L').Default("20").Int() + httpArgs.CheckParentInterval = http.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int() //########tcp######### tcp := app.Command("tcp", "proxy on tcp mode") tcpArgs.Timeout = tcp.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Short('t').Default("2000").Int() tcpArgs.ParentType = tcp.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "udp") tcpArgs.IsTLS = tcp.Flag("tls", "proxy on tls mode").Default("false").Bool() + tcpArgs.PoolSize = tcp.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Short('L').Default("20").Int() + tcpArgs.CheckParentInterval = tcp.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int() + //########udp######### udp := app.Command("udp", "proxy on udp mode") udpArgs.Timeout = udp.Flag("timeout", "tcp timeout milliseconds when connect to parent proxy").Short('t').Default("2000").Int() udpArgs.ParentType = udp.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "udp") - //########tunnel######### - tunnel := app.Command("tcp", "proxy on tunnel mode") - tunnelArgs.Timeout = tunnel.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() + udpArgs.PoolSize = udp.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Short('L').Default("20").Int() + udpArgs.CheckParentInterval = udp.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int() + + //########tunnel-server######### + tunnelServer := app.Command("tserver", "proxy on tunnel server mode") + tunnelServerArgs.Timeout = tunnelServer.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() + tunnelServerArgs.IsUDP = tunnelServer.Flag("udp", "proxy on udp tunnel server mode").Default("false").Bool() + tunnelServerArgs.Key = tunnelServer.Flag("k", "key same with client").Default("default").String() + + //########tunnel-client######### + tunnelClient := app.Command("tclient", "proxy on tunnel client mode") + tunnelClientArgs.Timeout = tunnelClient.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() + tunnelClientArgs.IsUDP = tunnelClient.Flag("udp", "proxy on udp tunnel client mode").Default("false").Bool() + tunnelClientArgs.Key = tunnelClient.Flag("k", "key same with server").Default("default").String() + + //########tunnel-bridge######### + tunnelBridge := app.Command("tbridge", "proxy on tunnel bridge mode") + tunnelBridgeArgs.Timeout = tunnelBridge.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int() kingpin.MustParse(app.Parse(os.Args[1:])) if *certTLS != "" && *keyTLS != "" { args.CertBytes, args.KeyBytes = tlsBytes(*certTLS, *keyTLS) } + + //common args httpArgs.Args = args tcpArgs.Args = args - // tlsArgs.Args = args udpArgs.Args = args + tunnelBridgeArgs.Args = args + tunnelClientArgs.Args = args + tunnelServerArgs.Args = args //keygen if len(os.Args) > 1 { @@ -83,7 +107,9 @@ func initConfig() (err error) { services.Regist("http", services.NewHTTP(), httpArgs) services.Regist("tcp", services.NewTCP(), tcpArgs) services.Regist("udp", services.NewUDP(), udpArgs) - services.Regist("tunnel", services.NewTunnel(), tunnelArgs) + services.Regist("tserver", services.NewTunnelServer(), tunnelServerArgs) + services.Regist("tclient", services.NewTunnelClient(), tunnelClientArgs) + services.Regist("tbridge", services.NewTunnelBridge(), tunnelBridgeArgs) service, err = services.Run(serviceName) if err != nil { log.Fatalf("run service [%s] fail, ERR:%s", service, err) diff --git a/services/args.go b/services/args.go index 8352114..10ba9c4 100644 --- a/services/args.go +++ b/services/args.go @@ -4,48 +4,67 @@ package services // t := tcp.Flag("tcp-timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int() const ( - TYPE_TCP = "tcp" - TYPE_UDP = "udp" - TYPE_HTTP = "http" - TYPE_TLS = "tls" + TYPE_TCP = "tcp" + TYPE_UDP = "udp" + TYPE_HTTP = "http" + TYPE_TLS = "tls" + CONN_CONTROL = uint8(1) + CONN_SERVER = uint8(2) + CONN_CLIENT = uint8(3) ) type Args struct { - Local *string - Parent *string - CertBytes []byte - KeyBytes []byte - PoolSize *int - CheckParentInterval *int + Local *string + Parent *string + CertBytes []byte + KeyBytes []byte } -type TunnelArgs struct { +type TunnelServerArgs struct { + Args + IsUDP *bool + Key *string + Timeout *int +} +type TunnelClientArgs struct { + Args + IsUDP *bool + Key *string + Timeout *int +} +type TunnelBridgeArgs struct { Args Timeout *int } type TCPArgs struct { Args - Timeout *int - ParentType *string - IsTLS *bool + ParentType *string + IsTLS *bool + Timeout *int + PoolSize *int + CheckParentInterval *int } type HTTPArgs struct { Args - Always *bool - HTTPTimeout *int - Timeout *int - Interval *int - Blocked *string - Direct *string - AuthFile *string - Auth *[]string - ParentType *string - LocalType *string + Always *bool + HTTPTimeout *int + Interval *int + Blocked *string + Direct *string + AuthFile *string + Auth *[]string + ParentType *string + LocalType *string + Timeout *int + PoolSize *int + CheckParentInterval *int } type UDPArgs struct { Args - ParentType *string - Timeout *int + ParentType *string + Timeout *int + PoolSize *int + CheckParentInterval *int } func (a *TCPArgs) Protocol() string { diff --git a/services/http.go b/services/http.go index a0d3e32..621b5d8 100644 --- a/services/http.go +++ b/services/http.go @@ -139,7 +139,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut } else { outConn.Write(req.HeadBuf) } - utils.IoBind((*inConn), outConn, func(err error) { + utils.IoBind((*inConn), outConn, func(isSrcErr bool, err error) { log.Printf("conn %s - %s - %s -%s released [%s]", inAddr, inLocalAddr, outLocalAddr, outAddr, req.Host) utils.CloseConn(inConn) utils.CloseConn(&outConn) diff --git a/services/tcp.go b/services/tcp.go index 7a2b487..d505012 100644 --- a/services/tcp.go +++ b/services/tcp.go @@ -97,7 +97,7 @@ func (s *TCP) OutToTCP(inConn *net.Conn) (err error) { inLocalAddr := (*inConn).LocalAddr().String() outAddr := outConn.RemoteAddr().String() outLocalAddr := outConn.LocalAddr().String() - utils.IoBind((*inConn), outConn, func(err error) { + utils.IoBind((*inConn), outConn, func(isSrcErr bool, err error) { log.Printf("conn %s - %s - %s -%s released", inAddr, inLocalAddr, outLocalAddr, outAddr) utils.CloseConn(inConn) utils.CloseConn(&outConn) diff --git a/services/tunnel.go b/services/tunnel.go deleted file mode 100644 index a92d5b0..0000000 --- a/services/tunnel.go +++ /dev/null @@ -1,39 +0,0 @@ -package services - -import "log" - -type Tunnel struct { - cfg TunnelArgs -} - -func NewTunnel() Service { - return &Tunnel{ - cfg: TunnelArgs{}, - } -} - -func (s *Tunnel) InitService() { - -} -func (s *Tunnel) Check() { - if *s.cfg.Parent != "" { - log.Printf("use tls parent %s", *s.cfg.Parent) - } else { - log.Fatalf("parent required") - } - if s.cfg.CertBytes == nil || s.cfg.KeyBytes == nil { - log.Fatalf("cert and key file required") - } -} -func (s *Tunnel) StopService() { - -} -func (s *Tunnel) Start(args interface{}) (err error) { - s.cfg = args.(TunnelArgs) - s.Check() - s.InitService() - return -} -func (s *Tunnel) Clean() { - s.StopService() -} diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go new file mode 100644 index 0000000..d4fa865 --- /dev/null +++ b/services/tunnel_bridge.go @@ -0,0 +1,180 @@ +package services + +import ( + "bufio" + "encoding/binary" + "fmt" + "log" + "net" + "proxy/utils" + "strconv" + "sync" + "time" +) + +type BridgeItem struct { + ServerChn chan *net.Conn + ClientChn chan *net.Conn + ClientControl *net.Conn + Once *sync.Once + Key string +} +type TunnelBridge struct { + cfg TunnelBridgeArgs + br utils.ConcurrentMap +} + +func NewTunnelBridge() Service { + return &TunnelBridge{ + cfg: TunnelBridgeArgs{}, + br: utils.NewConcurrentMap(), + } +} + +func (s *TunnelBridge) InitService() { + +} +func (s *TunnelBridge) Check() { + if s.cfg.CertBytes == nil || s.cfg.KeyBytes == nil { + log.Fatalf("cert and key file required") + } + +} +func (s *TunnelBridge) StopService() { + +} +func (s *TunnelBridge) Start(args interface{}) (err error) { + s.cfg = args.(TunnelBridgeArgs) + s.Check() + s.InitService() + host, port, _ := net.SplitHostPort(*s.cfg.Local) + p, _ := strconv.Atoi(port) + sc := utils.NewServerChannel(host, p) + + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) { + reader := bufio.NewReader(inConn) + var connType uint8 + err = binary.Read(reader, binary.LittleEndian, &connType) + if err != nil { + utils.CloseConn(&inConn) + return + } + var key string + var connTypeStrMap = map[uint8]string{CONN_SERVER: "server", CONN_CLIENT: "client", CONN_CONTROL: "client"} + if connType == CONN_SERVER || connType == CONN_CLIENT || connType == CONN_CONTROL { + var keyLength uint16 + err = binary.Read(reader, binary.LittleEndian, &keyLength) + if err != nil { + return + } + _key := make([]byte, keyLength) + n, err := reader.Read(_key) + if err != nil { + return + } + if n != int(keyLength) { + return + } + key = string(_key) + log.Printf("connection from %s , key: %s", connTypeStrMap[connType], key) + } + switch connType { + case CONN_SERVER: + s.ServerConn(&inConn, key) + case CONN_CLIENT: + s.ClientConn(&inConn, key) + case CONN_CONTROL: + s.ClientControlConn(&inConn, key) + default: + log.Printf("unkown conn type %d", connType) + utils.CloseConn(&inConn) + } + }) + if err != nil { + return + } + log.Printf("proxy on tunnel bridge mode %s", (*sc.Listener).Addr()) + return +} +func (s *TunnelBridge) Clean() { + s.StopService() +} +func (s *TunnelBridge) ClientConn(inConn *net.Conn, key string) { + chn, _ := s.ConnChn(key, CONN_CLIENT) + chn <- inConn +} +func (s *TunnelBridge) ServerConn(inConn *net.Conn, key string) { + chn, _ := s.ConnChn(key, CONN_SERVER) + chn <- inConn +} +func (s *TunnelBridge) ClientControlConn(inConn *net.Conn, key string) { + _, item := s.ConnChn(key, CONN_CLIENT) + utils.CloseConn(item.ClientControl) + if item.ClientControl != nil { + *item.ClientControl = *inConn + } else { + item.ClientControl = inConn + } + log.Printf("set client control conn,remote: %s", (*inConn).RemoteAddr()) +} +func (s *TunnelBridge) ConnChn(key string, typ uint8) (chn chan *net.Conn, item *BridgeItem) { + s.br.SetIfAbsent(key, &BridgeItem{ + ServerChn: make(chan *net.Conn, 10000), + ClientChn: make(chan *net.Conn, 10000), + Once: &sync.Once{}, + Key: key, + }) + _item, _ := s.br.Get(key) + item = _item.(*BridgeItem) + item.Once.Do(func() { + s.ChnDeamon(item) + }) + if typ == CONN_CLIENT { + chn = item.ClientChn + } else { + chn = item.ServerChn + } + return +} +func (s *TunnelBridge) ChnDeamon(item *BridgeItem) { + go func() { + log.Printf("%s conn chan deamon started", item.Key) + for { + var clientConn *net.Conn + var serverConn *net.Conn + serverConn = <-item.ServerChn + log.Printf("%s server conn picked up", item.Key) + OUT: + for { + _item, _ := s.br.Get(item.Key) + Item := _item.(*BridgeItem) + var err error + if Item.ClientControl != nil && *Item.ClientControl != nil { + _, err = (*Item.ClientControl).Write([]byte{'0'}) + } else { + err = fmt.Errorf("client control conn not exists") + } + if err != nil { + log.Printf("%s client control conn write signal fail, err: %s, retrying...", item.Key, err) + time.Sleep(time.Second * 3) + continue + } else { + select { + case clientConn = <-item.ClientChn: + log.Printf("%s client conn picked up", item.Key) + break OUT + case <-time.After(time.Second * time.Duration(*s.cfg.Timeout*5)): + log.Printf("%s client conn picked timeout, retrying...", item.Key) + } + } + } + + utils.IoBind(*serverConn, *clientConn, func(isSrcErr bool, err error) { + utils.CloseConn(serverConn) + utils.CloseConn(clientConn) + log.Printf("%s conn %s - %s - %s - %s released", item.Key, (*serverConn).RemoteAddr(), (*serverConn).LocalAddr(), (*clientConn).LocalAddr(), (*clientConn).RemoteAddr()) + }, func(i int, b bool) {}, 0) + log.Printf("%s conn %s - %s - %s - %s created", item.Key, (*serverConn).RemoteAddr(), (*serverConn).LocalAddr(), (*clientConn).LocalAddr(), (*clientConn).RemoteAddr()) + } + }() +} diff --git a/services/tunnel_client.go b/services/tunnel_client.go new file mode 100644 index 0000000..f25a5e5 --- /dev/null +++ b/services/tunnel_client.go @@ -0,0 +1,214 @@ +package services + +import ( + "bytes" + "crypto/tls" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "proxy/utils" + "time" +) + +type TunnelClient struct { + cfg TunnelClientArgs +} + +func NewTunnelClient() Service { + return &TunnelClient{ + cfg: TunnelClientArgs{}, + } +} + +func (s *TunnelClient) InitService() { +} +func (s *TunnelClient) Check() { + if *s.cfg.Parent != "" { + log.Printf("use tls parent %s", *s.cfg.Parent) + } else { + log.Fatalf("parent required") + } + if s.cfg.CertBytes == nil || s.cfg.KeyBytes == nil { + log.Fatalf("cert and key file required") + } +} +func (s *TunnelClient) StopService() { +} +func (s *TunnelClient) Start(args interface{}) (err error) { + s.cfg = args.(TunnelClientArgs) + s.Check() + s.InitService() + + for { + ctrlConn, err := s.GetInConn(CONN_CONTROL) + if err != nil { + log.Printf("control connection err: %s", err) + time.Sleep(time.Second * 3) + utils.CloseConn(&ctrlConn) + continue + } + if *s.cfg.IsUDP { + log.Printf("proxy on udp tunnel client mode") + } else { + log.Printf("proxy on tcp tunnel client mode") + } + for { + signal := make([]byte, 1) + if signal[0] == 1 { + continue + } + _, err = ctrlConn.Read(signal) + if err != nil { + utils.CloseConn(&ctrlConn) + log.Printf("read connection signal err: %s", err) + break + } + log.Printf("signal revecived:%s", signal) + if *s.cfg.IsUDP { + go s.ServeUDP() + } else { + go s.ServeConn() + } + } + } +} +func (s *TunnelClient) Clean() { + s.StopService() +} +func (s *TunnelClient) GetInConn(typ uint8) (outConn net.Conn, err error) { + outConn, err = s.GetConn() + if err != nil { + err = fmt.Errorf("connection err: %s", err) + return + } + keyBytes := []byte(*s.cfg.Key) + keyLength := uint16(len(keyBytes)) + pkg := new(bytes.Buffer) + binary.Write(pkg, binary.LittleEndian, typ) + binary.Write(pkg, binary.LittleEndian, keyLength) + binary.Write(pkg, binary.LittleEndian, keyBytes) + _, err = outConn.Write(pkg.Bytes()) + if err != nil { + err = fmt.Errorf("write connection data err: %s ,retrying...", err) + utils.CloseConn(&outConn) + return + } + return +} +func (s *TunnelClient) GetConn() (conn net.Conn, err error) { + var _conn tls.Conn + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + if err == nil { + conn = net.Conn(&_conn) + } + return +} +func (s *TunnelClient) ServeUDP() { + var inConn net.Conn + var err error + for { + for { + inConn, err = s.GetInConn(CONN_CLIENT) + if err != nil { + utils.CloseConn(&inConn) + log.Printf("connection err: %s, retrying...", err) + time.Sleep(time.Second * 3) + continue + } else { + break + } + } + log.Printf("conn created , remote : %s ", inConn.RemoteAddr()) + for { + srcAddr, body, err := utils.ReadUDPPacket(&inConn) + if err == io.EOF || err == io.ErrUnexpectedEOF { + log.Printf("connection %s released", srcAddr) + utils.CloseConn(&inConn) + break + } + log.Printf("udp packet revecived:%s,%v", srcAddr, body) + go s.processUDPPacket(&inConn, srcAddr, body) + } + } +} +func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr string, body []byte) { + dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Local) + if err != nil { + log.Printf("can't resolve address: %s", err) + utils.CloseConn(inConn) + return + } + clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) + if err != nil { + log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) + return + } + conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) + _, err = conn.Write(body) + if err != nil { + log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) + return + } + //log.Debugf("send udp packet to %s success", dstAddr.String()) + buf := make([]byte, 512) + len, _, err := conn.ReadFromUDP(buf) + if err != nil { + log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) + return + } + respBody := buf[0:len] + //log.Debugf("revecived udp packet from %s , %v", dstAddr.String(), respBody) + _, err = (*inConn).Write(utils.UDPPacket(srcAddr, respBody)) + if err != nil { + log.Printf("send udp response fail ,ERR:%s", err) + utils.CloseConn(inConn) + return + } + log.Printf("send udp response success ,from:%s", dstAddr.String()) +} +func (s *TunnelClient) ServeConn() { + var inConn, outConn net.Conn + var err error + for { + inConn, err = s.GetInConn(CONN_CLIENT) + if err != nil { + utils.CloseConn(&inConn) + log.Printf("connection err: %s, retrying...", err) + time.Sleep(time.Second * 3) + continue + } else { + break + } + } + + i := 0 + for { + i++ + outConn, err = utils.ConnectHost(*s.cfg.Local, *s.cfg.Timeout) + if err == nil || i == 3 { + break + } else { + if i == 3 { + log.Printf("connect to %s err: %s, retrying...", *s.cfg.Local, err) + time.Sleep(2 * time.Second) + continue + } + } + } + + if err != nil { + utils.CloseConn(&inConn) + utils.CloseConn(&outConn) + return + } + + utils.IoBind(inConn, outConn, func(isSrcErr bool, err error) { + log.Printf("%s conn %s - %s - %s - %s released", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr()) + utils.CloseConn(&inConn) + utils.CloseConn(&outConn) + }, func(i int, b bool) {}, 0) + log.Printf("%s conn %s - %s - %s - %s created", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr()) +} diff --git a/services/tunnel_server.go b/services/tunnel_server.go new file mode 100644 index 0000000..643f168 --- /dev/null +++ b/services/tunnel_server.go @@ -0,0 +1,209 @@ +package services + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/binary" + "io" + "log" + "net" + "proxy/utils" + "runtime/debug" + "strconv" + "strings" + "time" +) + +type TunnelServer struct { + cfg TunnelServerArgs + udpChn chan UDPItem + sc utils.ServerChannel +} + +func NewTunnelServer() Service { + return &TunnelServer{ + cfg: TunnelServerArgs{}, + udpChn: make(chan UDPItem, 50000), + } +} + +type UDPItem struct { + packet *[]byte + localAddr *net.UDPAddr + srcAddr *net.UDPAddr +} + +func (s *TunnelServer) InitService() { + s.UDPConnDeamon() +} +func (s *TunnelServer) Check() { + if *s.cfg.Parent != "" { + log.Printf("use tls parent %s", *s.cfg.Parent) + } else { + log.Fatalf("parent required") + } + if s.cfg.CertBytes == nil || s.cfg.KeyBytes == nil { + log.Fatalf("cert and key file required") + } +} +func (s *TunnelServer) StopService() { +} +func (s *TunnelServer) Start(args interface{}) (err error) { + s.cfg = args.(TunnelServerArgs) + s.Check() + s.InitService() + host, port, _ := net.SplitHostPort(*s.cfg.Local) + p, _ := strconv.Atoi(port) + s.sc = utils.NewServerChannel(host, p) + + if *s.cfg.IsUDP { + err = s.sc.ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) { + s.udpChn <- UDPItem{ + packet: &packet, + localAddr: localAddr, + srcAddr: srcAddr, + } + }) + if err != nil { + return + } + log.Printf("proxy on udp tunnel server mode %s", (*s.sc.UDPListener).LocalAddr()) + } else { + err = s.sc.ListenTCP(func(inConn net.Conn) { + defer func() { + if err := recover(); err != nil { + log.Printf("tserver conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) + } + }() + var outConn net.Conn + for { + outConn, err = s.GetOutConn() + if err != nil { + utils.CloseConn(&outConn) + log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) + time.Sleep(time.Second * 3) + continue + } else { + break + } + } + + utils.IoBind(inConn, outConn, func(isSrcErr bool, err error) { + utils.CloseConn(&outConn) + utils.CloseConn(&inConn) + log.Printf("%s conn %s - %s - %s - %s released", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr()) + }, func(i int, b bool) {}, 0) + + log.Printf("%s conn %s - %s - %s - %s created", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr()) + }) + if err != nil { + return + } + log.Printf("proxy on tunnel server mode %s", (*s.sc.Listener).Addr()) + } + return +} +func (s *TunnelServer) Clean() { + s.StopService() +} +func (s *TunnelServer) GetOutConn() (outConn net.Conn, err error) { + outConn, err = s.GetConn() + if err != nil { + log.Printf("connection err: %s", err) + return + } + keyBytes := []byte(*s.cfg.Key) + keyLength := uint16(len(keyBytes)) + pkg := new(bytes.Buffer) + binary.Write(pkg, binary.LittleEndian, CONN_SERVER) + binary.Write(pkg, binary.LittleEndian, keyLength) + binary.Write(pkg, binary.LittleEndian, keyBytes) + _, err = outConn.Write(pkg.Bytes()) + if err != nil { + log.Printf("write connection data err: %s ,retrying...", err) + utils.CloseConn(&outConn) + return + } + return +} +func (s *TunnelServer) GetConn() (conn net.Conn, err error) { + var _conn tls.Conn + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + if err == nil { + conn = net.Conn(&_conn) + } + return +} +func (s *TunnelServer) UDPConnDeamon() { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("udp conn deamon crashed with err : %s \nstack: %s", err, string(debug.Stack())) + } + }() + var outConn net.Conn + var cmdChn = make(chan bool, 1) + + var err error + for { + item := <-s.udpChn + RETRY: + if outConn == nil { + for { + outConn, err = s.GetOutConn() + if err != nil { + cmdChn <- true + outConn = nil + utils.CloseConn(&outConn) + log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) + time.Sleep(time.Second * 3) + continue + } else { + go func(outConn net.Conn) { + go func() { + <-cmdChn + outConn.Close() + }() + for { + srcAddrFromConn, body, err := utils.ReadUDPPacket(&outConn) + if err == io.EOF || err == io.ErrUnexpectedEOF { + log.Printf("udp connection deamon exited, %s -> %s", outConn.LocalAddr(), outConn.RemoteAddr()) + break + } + if err != nil { + log.Printf("parse revecived udp packet fail, err: %s", err) + continue + } + log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn) + _srcAddr := strings.Split(srcAddrFromConn, ":") + if len(_srcAddr) != 2 { + log.Printf("parse revecived udp packet fail, addr error : %s", srcAddrFromConn) + continue + } + port, _ := strconv.Atoi(_srcAddr[1]) + dstAddr := &net.UDPAddr{IP: net.ParseIP(_srcAddr[0]), Port: port} + _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr) + if err != nil { + log.Printf("udp response to local %s fail,ERR:%s", srcAddrFromConn, err) + continue + } + log.Printf("udp response to local %s success", srcAddrFromConn) + } + }(outConn) + break + } + } + } + writer := bufio.NewWriter(outConn) + writer.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet)) + err := writer.Flush() + if err != nil { + outConn = nil + log.Printf("write udp packet to %s fail ,flush err:%s", *s.cfg.Parent, err) + goto RETRY + } + log.Printf("write packet %v", *item.packet) + } + }() +} diff --git a/utils/functions.go b/utils/functions.go index 31f8277..d98363c 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "os/exec" + "sync" "runtime/debug" "strconv" @@ -21,59 +22,61 @@ import ( "time" ) -func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(err error), cfn func(count int, isPositive bool), bytesPreSec float64) { +func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) { + var one = &sync.Once{} go func() { defer func() { if e := recover(); e != nil { log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() - errchn := make(chan error, 2) - go func() { - defer func() { - if e := recover(); e != nil { - log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) - } - }() - var err error - if bytesPreSec > 0 { - newreader := NewReader(src) - newreader.SetRateLimit(bytesPreSec) - _, err = ioCopy(dst, newreader, func(c int) { - cfn(c, false) - }) + var err error + var isSrcErr bool + if bytesPreSec > 0 { + newreader := NewReader(src) + newreader.SetRateLimit(bytesPreSec) + _, isSrcErr, err = ioCopy(dst, newreader, func(c int) { + cfn(c, false) + }) - } else { - _, err = ioCopy(dst, src, func(c int) { - cfn(c, false) - }) + } else { + _, isSrcErr, err = ioCopy(dst, src, func(c int) { + cfn(c, false) + }) + } + if err != nil { + one.Do(func() { + fn(isSrcErr, err) + }) + } + }() + go func() { + defer func() { + if e := recover(); e != nil { + log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } - errchn <- err }() - go func() { - defer func() { - if e := recover(); e != nil { - log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) - } - }() - var err error - if bytesPreSec > 0 { - newReader := NewReader(dst) - newReader.SetRateLimit(bytesPreSec) - _, err = ioCopy(src, newReader, func(c int) { - cfn(c, true) - }) - } else { - _, err = ioCopy(src, dst, func(c int) { - cfn(c, true) - }) - } - errchn <- err - }() - fn(<-errchn) + var err error + var isSrcErr bool + if bytesPreSec > 0 { + newReader := NewReader(dst) + newReader.SetRateLimit(bytesPreSec) + _, isSrcErr, err = ioCopy(src, newReader, func(c int) { + cfn(c, true) + }) + } else { + _, isSrcErr, err = ioCopy(src, dst, func(c int) { + cfn(c, true) + }) + } + if err != nil { + one.Do(func() { + fn(isSrcErr, err) + }) + } }() } -func ioCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, err error) { +func ioCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) { buf := make([]byte, 32*1024) for { nr, er := src.Read(buf) @@ -96,10 +99,11 @@ func ioCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, } if er != nil { err = er + isSrcErr = true break } } - return written, err + return written, isSrcErr, err } func TlsConnectHost(host string, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) { h := strings.Split(host, ":") @@ -191,7 +195,7 @@ func HTTPGet(URL string, timeout int) (err error) { } func CloseConn(conn *net.Conn) { - if *conn != nil { + if conn != nil && *conn != nil { (*conn).SetDeadline(time.Now().Add(time.Millisecond)) (*conn).Close() }