diff --git a/CHANGELOG b/CHANGELOG index 5122d0b..ab831d6 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,4 +1,10 @@ proxy更新日志 +v3.3 +1.修复了socks代理模式对证书文件的判断逻辑. +2.增强了http代理,socks代理的ssh中转模式的稳定性. +3.socks代理tls,tcp模式新增了CMD_ASSOCIATE(udp)支持.socks代理ssh模式不支持udp. +4.修复了http代理某些情况下会崩溃的bug. + v3.2 1.内网穿透功能server端-r参数增加了协议和key设置. 2.手册增加了对-r参数的详细说明. diff --git a/README.md b/README.md index 7fbb488..3c6fbb4 100644 --- a/README.md +++ b/README.md @@ -423,7 +423,9 @@ server连接到bridge的时候,如果同时有多个client连接到同一个brid `./proxy help socks` ### TODO -- SOCKS5增加用户名密码认证 +- SOCKS5增加用户名密码认证? +- http,socks代理多个上级负载均衡? +- 欢迎加群反馈... ### 如何使用源码? cd进入你的go src目录,然后git clone https://github.com/snail007/goproxy.git ./proxy 即可. diff --git a/config.go b/config.go index 190df03..64fe49d 100755 --- a/config.go +++ b/config.go @@ -35,7 +35,7 @@ func initConfig() (err error) { //build srvice args app = kingpin.New("proxy", "happy with proxy") app.Author("snail").Version(APP_VERSION) - + debug := app.Flag("debug", "debug log output").Default("false").Bool() //########http######### http := app.Command("http", "proxy on http mode") httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() @@ -126,7 +126,13 @@ func initConfig() (err error) { socksArgs.Direct = socks.Flag("direct", "direct domain file , one domain each line").Default("direct").Short('d').String() //parse args serviceName := kingpin.MustParse(app.Parse(os.Args[1:])) - + flags := log.Ldate + if *debug { + flags |= log.Lshortfile | log.Lmicroseconds + } else { + flags |= log.Ltime + } + log.SetFlags(flags) poster() //regist services and run service services.Regist("http", services.NewHTTP(), httpArgs) diff --git a/install_auto.sh b/install_auto.sh index 88241cd..b48a25d 100755 --- a/install_auto.sh +++ b/install_auto.sh @@ -6,7 +6,7 @@ fi mkdir /tmp/proxy cd /tmp/proxy wget https://github.com/reddec/monexec/releases/download/v0.1.1/monexec_0.1.1_linux_amd64.tar.gz -wget https://github.com/snail007/goproxy/releases/download/v3.2/proxy-linux-amd64.tar.gz +wget https://github.com/snail007/goproxy/releases/download/v3.3/proxy-linux-amd64.tar.gz # install monexec tar zxvf monexec_0.1.1_linux_amd64.tar.gz diff --git a/main.go b/main.go index 52c2d9d..de2ae51 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,7 @@ import ( "syscall" ) -const APP_VERSION = "3.2" +const APP_VERSION = "3.3" func main() { err := initConfig() diff --git a/release.sh b/release.sh index b72697a..ddf9960 100755 --- a/release.sh +++ b/release.sh @@ -1,5 +1,5 @@ #!/bin/bash -VER="3.2" +VER="3.3" RELEASE="release-${VER}" rm -rf .cert mkdir .cert diff --git a/services/http.go b/services/http.go index 9fbf27f..8106dad 100644 --- a/services/http.go +++ b/services/http.go @@ -20,6 +20,7 @@ type HTTP struct { checker utils.Checker basicAuth utils.BasicAuth sshClient *ssh.Client + lockChn chan bool } func NewHTTP() Service { @@ -28,6 +29,7 @@ func NewHTTP() Service { cfg: HTTPArgs{}, checker: utils.Checker{}, basicAuth: utils.BasicAuth{}, + lockChn: make(chan bool, 1), } } func (s *HTTP) CheckArgs() { @@ -115,7 +117,9 @@ func (s *HTTP) callback(inConn net.Conn) { log.Printf("http(s) conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) } }() - req, err := utils.NewHTTPRequest(&inConn, 4096, s.IsBasicAuth(), &s.basicAuth) + var err interface{} + var req utils.HTTPRequest + req, err = utils.NewHTTPRequest(&inConn, 4096, s.IsBasicAuth(), &s.basicAuth) if err != nil { if err != io.EOF { log.Printf("decoder error , form %s, ERR:%s", err, inConn.RemoteAddr()) @@ -153,7 +157,7 @@ func (s *HTTP) callback(inConn net.Conn) { utils.CloseConn(&inConn) } } -func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *utils.HTTPRequest) (err error) { +func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *utils.HTTPRequest) (err interface{}) { inAddr := (*inConn).RemoteAddr().String() inLocalAddr := (*inConn).LocalAddr().String() //防止死循环 @@ -208,18 +212,27 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut return } -func (s *HTTP) getSSHConn(host string) (outConn net.Conn, err error) { +func (s *HTTP) getSSHConn(host string) (outConn net.Conn, err interface{}) { maxTryCount := 1 tryCount := 0 + errchn := make(chan interface{}, 1) RETRY: if tryCount >= maxTryCount { return } - outConn, err = s.sshClient.Dial("tcp", host) - //log.Printf("s.sshClient.Dial, host:%s)", host) + go func() { + defer func() { + if err == nil { + errchn <- recover() + } else { + errchn <- nil + } + }() + outConn, err = s.sshClient.Dial("tcp", host) + }() + err = <-errchn if err != nil { log.Printf("connect ssh fail, ERR: %s, retrying...", err) - s.sshClient.Close() e := s.ConnectSSH() if e == nil { tryCount++ @@ -232,14 +245,25 @@ RETRY: return } func (s *HTTP) ConnectSSH() (err error) { + select { + case s.lockChn <- true: + default: + err = fmt.Errorf("can not connect at same time") + return + } config := ssh.ClientConfig{ - User: *s.cfg.SSHUser, - Auth: []ssh.AuthMethod{s.cfg.SSHAuthMethod}, + Timeout: time.Duration(*s.cfg.Timeout) * time.Millisecond, + User: *s.cfg.SSHUser, + Auth: []ssh.AuthMethod{s.cfg.SSHAuthMethod}, HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, } + if s.sshClient != nil { + s.sshClient.Close() + } s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config) + <-s.lockChn return } func (s *HTTP) InitOutConnPool() { diff --git a/services/socks.go b/services/socks.go index ee636ad..bde3be9 100644 --- a/services/socks.go +++ b/services/socks.go @@ -1,15 +1,16 @@ package services import ( - "bytes" "crypto/tls" - "encoding/binary" "fmt" "io" "io/ioutil" "log" "net" "proxy/utils" + "proxy/utils/aes" + "proxy/utils/socks" + "runtime/debug" "time" "golang.org/x/crypto/ssh" @@ -20,6 +21,8 @@ type Socks struct { checker utils.Checker basicAuth utils.BasicAuth sshClient *ssh.Client + lockChn chan bool + udpSC utils.ServerChannel } func NewSocks() Service { @@ -27,16 +30,22 @@ func NewSocks() Service { cfg: SocksArgs{}, checker: utils.Checker{}, basicAuth: utils.BasicAuth{}, + lockChn: make(chan bool, 1), } } func (s *Socks) CheckArgs() { var err error + if *s.cfg.LocalType == "tls" { + //log.Println(*s.cfg.CertFile, *s.cfg.KeyFile) + s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) + } if *s.cfg.Parent != "" { if *s.cfg.ParentType == "" { log.Fatalf("parent type unkown,use -T ") } - if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" { + if *s.cfg.ParentType == "tls" { + log.Println(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) } if *s.cfg.ParentType == "ssh" { @@ -46,7 +55,6 @@ func (s *Socks) CheckArgs() { if *s.cfg.SSHKeyFile == "" && *s.cfg.SSHPassword == "" { log.Fatalf("ssh password or key required") } - if *s.cfg.SSHPassword != "" { s.cfg.SSHAuthMethod = ssh.Password(*s.cfg.SSHPassword) } else { @@ -76,12 +84,45 @@ func (s *Socks) InitService() { if err != nil { log.Fatalf("init service fail, ERR: %s", err) } + go func() { + //循环检查ssh网络连通性 + for { + conn, err := utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout*2) + if err != nil { + if s.sshClient != nil { + s.sshClient.Close() + if s.sshClient.Conn != nil { + s.sshClient.Conn.Close() + } + } + log.Printf("ssh offline, retrying...") + s.ConnectSSH() + } else { + conn.Close() + } + time.Sleep(time.Second * 3) + } + }() + } + if *s.cfg.ParentType == "ssh" { + log.Println("warn: socks udp not suppored for ssh") + } else { + _, port, _ := net.SplitHostPort(*s.cfg.Local) + s.udpSC = utils.NewServerChannelHost(":" + port) + err := s.udpSC.ListenUDP(s.udpCallback) + if err != nil { + log.Fatalf("init udp service fail, ERR: %s", err) + } + log.Printf("udp socks proxy on %s", s.udpSC.UDPListener.LocalAddr()) } } func (s *Socks) StopService() { if s.sshClient != nil { s.sshClient.Close() } + if s.udpSC.UDPListener != nil { + s.udpSC.UDPListener.Close() + } } func (s *Socks) Start(args interface{}) (err error) { //start() @@ -93,9 +134,9 @@ func (s *Socks) Start(args interface{}) (err error) { } sc := utils.NewServerChannelHost(*s.cfg.Local) if *s.cfg.LocalType == TYPE_TCP { - err = sc.ListenTCP(s.callback) + err = sc.ListenTCP(s.socksConnCallback) } else { - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.socksConnCallback) } if err != nil { return @@ -106,95 +147,230 @@ func (s *Socks) Start(args interface{}) (err error) { func (s *Socks) Clean() { s.StopService() } -func (s *Socks) callback(inConn net.Conn) { +func (s *Socks) UDPKey() []byte { + return s.cfg.KeyBytes[:32] +} +func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) { + rawB := b + var err error + if *s.cfg.LocalType == "tls" { + //decode b + rawB, err = goaes.Decrypt(s.UDPKey(), b) + if err != nil { + log.Printf("decrypt udp packet fail from %s", srcAddr.String()) + return + } + } + p, err := socks.ParseUDPPacket(rawB) + log.Printf("udp revecived:%v", len(p.Data())) + if err != nil { + log.Printf("parse udp packet fail, ERR:%s", err) + return + } + //log.Printf("##########udp to -> %s:%s###########", p.Host(), p.Port()) + if *s.cfg.Parent != "" { + //有上级代理,转发给上级 + if *s.cfg.ParentType == "tls" { + //encode b + rawB, err = goaes.Encrypt(s.UDPKey(), rawB) + if err != nil { + log.Printf("encrypt udp data fail to %s", *s.cfg.Parent) + return + } + } + dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Parent) + if err != nil { + log.Printf("can't resolve address: %s", err) + return + } + clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) + if err != nil { + log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) + return + } + conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*2))) + _, err = conn.Write(rawB) + log.Printf("udp request:%v", len(rawB)) + if err != nil { + log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) + conn.Close() + return + } + + //log.Printf("send udp packet to %s success", dstAddr.String()) + buf := make([]byte, 10*1024) + length, _, err := conn.ReadFromUDP(buf) + if err != nil { + log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) + conn.Close() + return + } + respBody := buf[0:length] + log.Printf("udp response:%v", len(respBody)) + //log.Printf("revecived udp packet from %s", dstAddr.String()) + if *s.cfg.ParentType == "tls" { + //decode b + respBody, err = goaes.Decrypt(s.UDPKey(), respBody) + if err != nil { + log.Printf("encrypt udp data fail to %s", *s.cfg.Parent) + conn.Close() + return + } + } + if *s.cfg.LocalType == "tls" { + d, err := goaes.Encrypt(s.UDPKey(), respBody) + if err != nil { + log.Printf("encrypt udp data fail from %s", dstAddr.String()) + conn.Close() + return + } + s.udpSC.UDPListener.WriteToUDP(d, srcAddr) + log.Printf("udp reply:%v", len(d)) + } else { + s.udpSC.UDPListener.WriteToUDP(respBody, srcAddr) + log.Printf("udp reply:%v", len(respBody)) + } + + } else { + //本地代理 + dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port())) + if err != nil { + log.Printf("can't resolve address: %s", err) + return + } + clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) + if err != nil { + log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) + return + } + conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*3))) + _, err = conn.Write(p.Data()) + log.Printf("udp send:%v", len(p.Data())) + if err != nil { + log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) + conn.Close() + return + } + //log.Printf("send udp packet to %s success", dstAddr.String()) + buf := make([]byte, 10*1024) + length, _, err := conn.ReadFromUDP(buf) + if err != nil { + log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) + conn.Close() + return + } + respBody := buf[0:length] + //封装来自真实服务器的数据,返回给访问者 + respPacket := p.NewReply(respBody) + //log.Printf("revecived udp packet from %s", dstAddr.String()) + if *s.cfg.LocalType == "tls" { + d, err := goaes.Encrypt(s.UDPKey(), respPacket) + if err != nil { + log.Printf("encrypt udp data fail from %s", dstAddr.String()) + conn.Close() + return + } + s.udpSC.UDPListener.WriteToUDP(d, srcAddr) + } else { + s.udpSC.UDPListener.WriteToUDP(respPacket, srcAddr) + } + log.Printf("udp reply:%v", len(respPacket)) + } + +} +func (s *Socks) socksConnCallback(inConn net.Conn) { defer func() { if err := recover(); err != nil { - //log.Printf("socks conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) + log.Printf("socks conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) } utils.CloseConn(&inConn) }() - var outConn net.Conn - defer utils.CloseConn(&outConn) - var b [1024]byte - n, err := inConn.Read(b[:]) - if err != nil { - if err != io.EOF { - log.Printf("read request data fail,ERR: %s", err) - } + //method select request + methodReq, err := socks.NewMethodsRequest(inConn) + if err != nil || !methodReq.Select(socks.Method_NO_AUTH) { + methodReq.Reply(socks.Method_NONE_ACCEPTABLE) + utils.CloseConn(&inConn) return } - var reqBytes = b[:n] - //log.Printf("% x", b[:n]) - - //reply - n, err = inConn.Write([]byte{0x05, 0x00}) + //method select reply + err = methodReq.Reply(socks.Method_NO_AUTH) if err != nil { log.Printf("reply answer data fail,ERR: %s", err) + utils.CloseConn(&inConn) return } - //read answer - n, err = inConn.Read(b[:]) + // log.Printf("% x", methodReq.Bytes()) + + //request detail + request, err := socks.NewRequest(inConn) if err != nil { - log.Printf("read answer data fail,ERR: %s", err) + log.Printf("read request data fail,ERR: %s", err) + utils.CloseConn(&inConn) return } - var headBytes = b[:n] - // log.Printf("% x", b[:n]) - var addr string - switch b[3] { - case 0x01: - sip := sockIP{} - if err := binary.Read(bytes.NewReader(b[4:n]), binary.BigEndian, &sip); err != nil { - log.Printf("read ip fail,ERR: %s", err) - return - } - addr = sip.toAddr() - case 0x03: - host := string(b[5 : n-2]) - var port uint16 - err = binary.Read(bytes.NewReader(b[n-2:n]), binary.BigEndian, &port) - if err != nil { - log.Printf("read domain fail,ERR: %s", err) - return - } - addr = fmt.Sprintf("%s:%d", host, port) + + switch request.CMD() { + case socks.CMD_BIND: + //bind 不支持 + request.TCPReply(socks.REP_UNKNOWN) + utils.CloseConn(&inConn) + return + case socks.CMD_CONNECT: + //tcp + s.proxyTCP(&inConn, methodReq, request) + case socks.CMD_ASSOCIATE: + //udp + s.proxyUDP(&inConn, methodReq, request) } + +} +func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) { + if *s.cfg.ParentType == "ssh" { + utils.CloseConn(inConn) + return + } + host, _, _ := net.SplitHostPort((*inConn).LocalAddr().String()) + _, port, _ := net.SplitHostPort(s.udpSC.UDPListener.LocalAddr().String()) + // log.Printf("proxy udp on %s", net.JoinHostPort(host, port)) + request.UDPReply(socks.REP_SUCCESS, net.JoinHostPort(host, port)) +} +func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) { + var outConn net.Conn + defer utils.CloseConn(&outConn) + var err interface{} useProxy := true if *s.cfg.Always { - outConn, err = s.getOutConn(reqBytes, headBytes, addr) + outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr()) } else { if *s.cfg.Parent != "" { - s.checker.Add(addr, true, "", "", nil) - useProxy, _, _ = s.checker.IsBlocked(addr) + s.checker.Add(request.Addr(), true, "", "", nil) + useProxy, _, _ = s.checker.IsBlocked(request.Addr()) if useProxy { - outConn, err = s.getOutConn(reqBytes, headBytes, addr) + outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr()) } else { - outConn, err = utils.ConnectHost(addr, *s.cfg.Timeout) + outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) } } else { - outConn, err = utils.ConnectHost(addr, *s.cfg.Timeout) + outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) } } if err != nil { log.Printf("get out conn fail,%s", err) - inConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + request.TCPReply(socks.REP_NETWOR_UNREACHABLE) return } - log.Printf("use proxy %v : %s", useProxy, addr) + log.Printf("use proxy %v : %s", useProxy, request.Addr()) - inConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + request.TCPReply(socks.REP_SUCCESS) + inAddr := (*inConn).RemoteAddr().String() + inLocalAddr := (*inConn).LocalAddr().String() - inAddr := inConn.RemoteAddr().String() - inLocalAddr := inConn.LocalAddr().String() - - log.Printf("conn %s - %s connected [%s]", inAddr, inLocalAddr, addr) - // utils.IoBind(outConn, inConn, func(err error) { - // log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, addr) - - // }, func(i int, b bool) {}, 0) + log.Printf("conn %s - %s connected [%s]", inAddr, inLocalAddr, request.Addr()) var bind = func() (err interface{}) { defer func() { if err == nil { @@ -211,17 +387,18 @@ func (s *Socks) callback(inConn net.Conn) { } } }() - _, err = io.Copy(outConn, inConn) + _, err = io.Copy(outConn, (*inConn)) }() - _, err = io.Copy(inConn, outConn) + _, err = io.Copy((*inConn), outConn) return } bind() - log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, addr) - utils.CloseConn(&inConn) + log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, request.Addr()) + utils.CloseConn(inConn) utils.CloseConn(&outConn) } -func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net.Conn, err error) { +func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn net.Conn, err interface{}) { + errchn := make(chan interface{}, 1) switch *s.cfg.ParentType { case "tls": fallthrough @@ -238,7 +415,7 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net } var buf = make([]byte, 1024) //var n int - _, err = outConn.Write(reqBytes) + _, err = outConn.Write(methodBytes) if err != nil { return } @@ -249,7 +426,7 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net //resp := buf[:n] //log.Printf("resp:%v", resp) - outConn.Write(headBytes) + outConn.Write(reqBytes) _, err = outConn.Read(buf) if err != nil { return @@ -264,10 +441,19 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net if tryCount >= maxTryCount { return } - outConn, err = s.sshClient.Dial("tcp", host) + go func() { + defer func() { + if err == nil { + errchn <- recover() + } else { + errchn <- nil + } + }() + outConn, err = s.sshClient.Dial("tcp", host) + }() + err = <-errchn if err != nil { log.Printf("connect ssh fail, ERR: %s, retrying...", err) - s.sshClient.Close() e := s.ConnectSSH() if e == nil { tryCount++ @@ -282,6 +468,12 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net return } func (s *Socks) ConnectSSH() (err error) { + select { + case s.lockChn <- true: + default: + err = fmt.Errorf("can not connect at same time") + return + } config := ssh.ClientConfig{ Timeout: time.Duration(*s.cfg.Timeout) * time.Millisecond, User: *s.cfg.SSHUser, @@ -290,15 +482,10 @@ func (s *Socks) ConnectSSH() (err error) { return nil }, } + if s.sshClient != nil { + s.sshClient.Close() + } s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config) + <-s.lockChn return } - -type sockIP struct { - A, B, C, D byte - PORT uint16 -} - -func (ip sockIP) toAddr() string { - return fmt.Sprintf("%d.%d.%d.%d:%d", ip.A, ip.B, ip.C, ip.D, ip.PORT) -} diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index 683ca91..ab9bd57 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -13,7 +13,6 @@ import ( type ServerConn struct { ClientLocalAddr string //tcp:2.2.22:333@ID Conn *net.Conn - //Conn *utils.HeartbeatReadWriter } type TunnelBridge struct { cfg TunnelBridgeArgs @@ -78,7 +77,6 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { return } key = string(_key) - //log.Printf("conn key %s", key) if connType != CONN_CONTROL { var IDLength uint16 @@ -117,13 +115,8 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { switch connType { case CONN_SERVER: - // hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { - // log.Printf("%s conn %s from server released", key, ID) - // s.serverConns.Remove(ID) - // }) addr := clientLocalAddr + "@" + ID s.serverConns.Set(ID, ServerConn{ - //Conn: &hb, Conn: &inConn, ClientLocalAddr: addr, }) @@ -134,7 +127,9 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { time.Sleep(time.Second * 3) continue } + (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) _, err := (*item.(*net.Conn)).Write([]byte(addr)) + (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) if err != nil { log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) time.Sleep(time.Second * 3) @@ -151,33 +146,36 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { return } serverConn := serverConnItem.(ServerConn).Conn - // hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) { - // log.Printf("%s conn %s from client released", key, ID) - // hw.Close() - // }) utils.IoBind(*serverConn, inConn, func(err error) { - // utils.IoBind(serverConn, inConn, func(isSrcErr bool, err error) { - //serverConn.Close() + (*serverConn).Close() utils.CloseConn(&inConn) - // hw.Close() s.serverConns.Remove(ID) log.Printf("conn %s released", ID) }, func(i int, b bool) {}, 0) log.Printf("conn %s created", ID) + case CONN_CONTROL: if s.clientControlConns.Has(key) { item, _ := s.clientControlConns.Get(key) - //(*item.(*utils.HeartbeatReadWriter)).Close() (*item.(*net.Conn)).Close() } - // hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { - // log.Printf("client %s disconnected", key) - // s.clientControlConns.Remove(key) - // }) - // s.clientControlConns.Set(key, &hb) s.clientControlConns.Set(key, &inConn) log.Printf("set client %s control conn", key) + go func() { + for { + var b = make([]byte, 1) + _, err = inConn.Read(b) + if err != nil { + inConn.Close() + s.serverConns.Remove(ID) + log.Printf("%s control conn from client released", key) + break + } else { + //log.Printf("%s heartbeat from client", key) + } + } + }() } }) if err != nil { diff --git a/services/tunnel_client.go b/services/tunnel_client.go index ff2e0d6..a77ad04 100644 --- a/services/tunnel_client.go +++ b/services/tunnel_client.go @@ -46,27 +46,33 @@ func (s *TunnelClient) Start(args interface{}) (err error) { for { ctrlConn, err := s.GetInConn(CONN_CONTROL, "") if err != nil { - log.Printf("control connection err: %s", err) + log.Printf("control connection err: %s, retrying...", err) time.Sleep(time.Second * 3) utils.CloseConn(&ctrlConn) continue } - // rw := utils.NewHeartbeatReadWriter(&ctrlConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { - // log.Printf("ctrlConn err %s", err) - // utils.CloseConn(&ctrlConn) - // }) + go func() { + for { + ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) + _, err = ctrlConn.Write([]byte{0x00}) + ctrlConn.SetWriteDeadline(time.Time{}) + if err != nil { + utils.CloseConn(&ctrlConn) + log.Printf("ctrlConn err %s", err) + break + } + time.Sleep(time.Second * 3) + } + }() for { signal := make([]byte, 50) - // n, err := rw.Read(signal) n, err := ctrlConn.Read(signal) if err != nil { utils.CloseConn(&ctrlConn) - log.Printf("read connection signal err: %s", err) + log.Printf("read connection signal err: %s, retrying...", err) break } addr := string(signal[:n]) - // log.Printf("n:%d addr:%s err:%s", n, addr, err) - // os.Exit(0) log.Printf("signal revecived:%s", addr) protocol := addr[:3] atIndex := strings.Index(addr, "@") diff --git a/services/tunnel_server.go b/services/tunnel_server.go index 08d85f8..f581a86 100644 --- a/services/tunnel_server.go +++ b/services/tunnel_server.go @@ -64,12 +64,15 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) { CertBytes: s.cfg.CertBytes, KeyBytes: s.cfg.KeyBytes, Parent: s.cfg.Parent, + CertFile: s.cfg.CertFile, + KeyFile: s.cfg.KeyFile, Local: &local, IsUDP: &IsUDP, Remote: &remote, Key: &KEY, Timeout: s.cfg.Timeout, }) + if err != nil { return } diff --git a/utils/aes/aes.go b/utils/aes/aes.go new file mode 100644 index 0000000..3d4536e --- /dev/null +++ b/utils/aes/aes.go @@ -0,0 +1,84 @@ +// Playbook - http://play.golang.org/p/3wFl4lacjX + +package goaes + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "io" + "strings" +) + +func addBase64Padding(value string) string { + m := len(value) % 4 + if m != 0 { + value += strings.Repeat("=", 4-m) + } + + return value +} + +func removeBase64Padding(value string) string { + return strings.Replace(value, "=", "", -1) +} + +func Pad(src []byte) []byte { + padding := aes.BlockSize - len(src)%aes.BlockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(src, padtext...) +} + +func Unpad(src []byte) ([]byte, error) { + length := len(src) + unpadding := int(src[length-1]) + + if unpadding > length { + return nil, errors.New("unpad error. This could happen when incorrect encryption key is used") + } + + return src[:(length - unpadding)], nil +} + +func Encrypt(key []byte, text []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + msg := Pad(text) + ciphertext := make([]byte, aes.BlockSize+len(msg)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + cfb := cipher.NewCFBEncrypter(block, iv) + cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(msg)) + + return ciphertext, nil +} + +func Decrypt(key []byte, text []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if (len(text) % aes.BlockSize) != 0 { + return nil, errors.New("blocksize must be multipe of decoded message length") + } + iv := text[:aes.BlockSize] + msg := text[aes.BlockSize:] + + cfb := cipher.NewCFBDecrypter(block, iv) + cfb.XORKeyStream(msg, msg) + + unpadMsg, err := Unpad(msg) + if err != nil { + return nil, err + } + + return unpadMsg, nil +} diff --git a/utils/functions.go b/utils/functions.go index 753c738..0382c1f 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -194,6 +194,9 @@ func HTTPGet(URL string, timeout int) (err error) { } func CloseConn(conn *net.Conn) { + defer func() { + _ = recover() + }() if conn != nil && *conn != nil { (*conn).SetDeadline(time.Now().Add(time.Millisecond)) (*conn).Close() @@ -312,6 +315,24 @@ func Uniqueid() string { s := fmt.Sprintf("%d", src.Int63()) return s[len(s)-5:len(s)-1] + fmt.Sprintf("%d", uint64(time.Now().UnixNano()))[8:] } +func SubStr(str string, start, end int) string { + if len(str) == 0 { + return "" + } + if end >= len(str) { + end = len(str) - 1 + } + return str[start:end] +} +func SubBytes(bytes []byte, start, end int) []byte { + if len(bytes) == 0 { + return []byte{} + } + if end >= len(bytes) { + end = len(bytes) - 1 + } + return bytes[start:end] +} func TlsBytes(cert, key string) (certBytes, keyBytes []byte) { certBytes, err := ioutil.ReadFile(cert) if err != nil { diff --git a/utils/socks/structs.go b/utils/socks/structs.go new file mode 100644 index 0000000..4e04b54 --- /dev/null +++ b/utils/socks/structs.go @@ -0,0 +1,260 @@ +package socks + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "strconv" +) + +const ( + Method_NO_AUTH = uint8(0x00) + Method_GSSAPI = uint8(0x01) + Method_USER_PASS = uint8(0x02) + Method_IANA = uint8(0x7F) + Method_RESVERVE = uint8(0x80) + Method_NONE_ACCEPTABLE = uint8(0xFF) + VERSION_V5 = uint8(0x05) + CMD_CONNECT = uint8(0x01) + CMD_BIND = uint8(0x02) + CMD_ASSOCIATE = uint8(0x03) + ATYP_IPV4 = uint8(0x01) + ATYP_DOMAIN = uint8(0x03) + ATYP_IPV6 = uint8(0x04) + REP_SUCCESS = uint8(0x00) + REP_REQ_FAIL = uint8(0x01) + REP_RULE_FORBIDDEN = uint8(0x02) + REP_NETWOR_UNREACHABLE = uint8(0x03) + REP_HOST_UNREACHABLE = uint8(0x04) + REP_CONNECTION_REFUSED = uint8(0x05) + REP_TTL_TIMEOUT = uint8(0x06) + REP_CMD_UNSUPPORTED = uint8(0x07) + REP_ATYP_UNSUPPORTED = uint8(0x08) + REP_UNKNOWN = uint8(0x09) + RSV = uint8(0x00) +) + +var ( + ZERO_IP = []byte{0x00, 0x00, 0x00, 0x00} + ZERO_PORT = []byte{0x00, 0x00} +) + +type Request struct { + ver uint8 + cmd uint8 + reserve uint8 + addressType uint8 + dstAddr string + dstPort string + dstHost string + bytes []byte + rw io.ReadWriter +} + +func NewRequest(rw io.ReadWriter) (req Request, err interface{}) { + var b [1024]byte + var n int + req = Request{rw: rw} + n, err = rw.Read(b[:]) + if err != nil { + err = fmt.Errorf("read req data fail,ERR: %s", err) + return + } + req.ver = uint8(b[0]) + req.cmd = uint8(b[1]) + req.reserve = uint8(b[2]) + req.addressType = uint8(b[3]) + + if b[0] != 0x5 { + err = fmt.Errorf("sosck version supported") + req.TCPReply(REP_REQ_FAIL) + return + } + switch b[3] { + case 0x01: //IP V4 + req.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String() + case 0x03: //域名 + req.dstHost = string(b[5 : n-2]) //b[4]表示域名的长度 + case 0x04: //IP V6 + req.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String() + } + req.dstPort = strconv.Itoa(int(b[n-2])<<8 | int(b[n-1])) + req.dstAddr = net.JoinHostPort(req.dstHost, req.dstPort) + req.bytes = b[:n] + return +} +func (s *Request) Bytes() []byte { + return s.bytes +} +func (s *Request) Addr() string { + return s.dstAddr +} +func (s *Request) Host() string { + return s.dstHost +} +func (s *Request) Port() string { + return s.dstPort +} +func (s *Request) AType() uint8 { + return s.addressType +} +func (s *Request) CMD() uint8 { + return s.cmd +} + +func (s *Request) TCPReply(rep uint8) (err error) { + _, err = s.rw.Write(s.NewReply(rep, "0.0.0.0:0")) + return +} +func (s *Request) UDPReply(rep uint8, addr string) (err error) { + _, err = s.rw.Write(s.NewReply(rep, addr)) + return +} +func (s *Request) NewReply(rep uint8, addr string) []byte { + var response bytes.Buffer + host, port, _ := net.SplitHostPort(addr) + ip := net.ParseIP(host) + ipb := ip.To4() + atyp := ATYP_IPV4 + ipv6 := ip.To16() + zeroiIPv6 := fmt.Sprintf("%d%d%d%d%d%d%d%d%d%d%d%d", + ipv6[0], ipv6[1], ipv6[2], ipv6[3], + ipv6[4], ipv6[5], ipv6[6], ipv6[7], + ipv6[8], ipv6[9], ipv6[10], ipv6[11], + ) + if ipv6 != nil && "0000000000255255" != zeroiIPv6 { + atyp = ATYP_IPV6 + ipb = ip.To16() + } + porti, _ := strconv.Atoi(port) + portb := make([]byte, 2) + binary.BigEndian.PutUint16(portb, uint16(porti)) + // log.Printf("atyp : %v", atyp) + // log.Printf("ip : %v", []byte(ip)) + response.WriteByte(VERSION_V5) + response.WriteByte(rep) + response.WriteByte(RSV) + response.WriteByte(atyp) + response.Write(ipb) + response.Write(portb) + return response.Bytes() +} + +type MethodsRequest struct { + ver uint8 + methodsCount uint8 + methods []uint8 + bytes []byte + rw *io.ReadWriter +} + +func NewMethodsRequest(r io.ReadWriter) (s MethodsRequest, err interface{}) { + defer func() { + if err == nil { + err = recover() + } + }() + s = MethodsRequest{} + s.rw = &r + var buf = make([]byte, 300) + var n int + n, err = r.Read(buf) + if err != nil { + return + } + if buf[0] != 0x05 { + err = fmt.Errorf("socks version not supported") + return + } + if n != int(buf[1])+int(2) { + err = fmt.Errorf("socks methods data length error") + return + } + + s.ver = buf[0] + s.methodsCount = buf[1] + s.methods = buf[2:n] + s.bytes = buf[:n] + return +} +func (s *MethodsRequest) Version() uint8 { + return s.ver +} +func (s *MethodsRequest) MethodsCount() uint8 { + return s.methodsCount +} +func (s *MethodsRequest) Select(method uint8) bool { + for _, m := range s.methods { + if m == method { + return true + } + } + return false +} +func (s *MethodsRequest) Reply(method uint8) (err error) { + _, err = (*s.rw).Write([]byte{byte(VERSION_V5), byte(method)}) + return +} +func (s *MethodsRequest) Bytes() []byte { + return s.bytes +} + +type UDPPacket struct { + rsv uint16 + frag uint8 + atype uint8 + dstHost string + dstPort string + data []byte + header []byte + bytes []byte +} + +func ParseUDPPacket(b []byte) (p UDPPacket, err error) { + p = UDPPacket{} + p.frag = uint8(b[2]) + p.bytes = b + if p.frag != 0 { + err = fmt.Errorf("FRAG only support for 0 , %v ,%v", p.frag, b[:4]) + return + } + portIndex := 0 + p.atype = b[3] + switch p.atype { + case ATYP_IPV4: //IP V4 + p.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String() + portIndex = 8 + case ATYP_DOMAIN: //域名 + domainLen := uint8(b[4]) + p.dstHost = string(b[5 : 5+domainLen]) //b[4]表示域名的长度 + portIndex = int(5 + domainLen) + case ATYP_IPV6: //IP V6 + p.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String() + portIndex = 20 + } + p.dstPort = strconv.Itoa(int(b[portIndex])<<8 | int(b[portIndex+1])) + p.data = b[portIndex+2:] + p.header = b[:portIndex+2] + return +} +func (s *UDPPacket) Header() []byte { + return s.header +} +func (s *UDPPacket) NewReply(data []byte) []byte { + var buf bytes.Buffer + buf.Write(s.header) + buf.Write(data) + return buf.Bytes() +} +func (s *UDPPacket) Host() string { + return s.dstHost +} + +func (s *UDPPacket) Port() string { + return s.dstPort +} +func (s *UDPPacket) Data() []byte { + return s.data +} diff --git a/utils/structs.go b/utils/structs.go index 57303e7..7f67635 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -256,13 +256,13 @@ func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth * req.HeadBuf = buf[:len] index := bytes.IndexByte(req.HeadBuf, '\n') if index == -1 { - err = fmt.Errorf("http decoder data line err:%s", string(req.HeadBuf)[:50]) + err = fmt.Errorf("http decoder data line err:%s", SubStr(string(req.HeadBuf), 0, 50)) CloseConn(inConn) return } fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL) if req.Method == "" || req.hostOrURL == "" { - err = fmt.Errorf("http decoder data err:%s", string(req.HeadBuf)[:50]) + err = fmt.Errorf("http decoder data err:%s", SubStr(string(req.HeadBuf), 0, 50)) CloseConn(inConn) return }