package socks import ( "crypto/tls" "fmt" "io" "io/ioutil" logger "log" "net" "runtime/debug" "strconv" "strings" "time" "github.com/snail007/goproxy/services" "github.com/snail007/goproxy/services/kcpcfg" "github.com/snail007/goproxy/utils" "github.com/snail007/goproxy/utils/aes" "github.com/snail007/goproxy/utils/conncrypt" "github.com/snail007/goproxy/utils/socks" "golang.org/x/crypto/ssh" ) type SocksArgs struct { Parent *string ParentType *string Local *string LocalType *string CertFile *string KeyFile *string CaCertFile *string CaCertBytes []byte CertBytes []byte KeyBytes []byte SSHKeyFile *string SSHKeyFileSalt *string SSHPassword *string SSHUser *string SSHKeyBytes []byte SSHAuthMethod ssh.AuthMethod Timeout *int Always *bool Interval *int Blocked *string Direct *string AuthFile *string Auth *[]string AuthURL *string AuthURLOkCode *int AuthURLTimeout *int AuthURLRetry *int KCP kcpcfg.KCPConfigArgs UDPParent *string UDPLocal *string LocalIPS *[]string DNSAddress *string DNSTTL *int LocalKey *string ParentKey *string LocalCompress *bool ParentCompress *bool } type Socks struct { cfg SocksArgs checker utils.Checker basicAuth utils.BasicAuth sshClient *ssh.Client lockChn chan bool udpSC utils.ServerChannel sc *utils.ServerChannel domainResolver utils.DomainResolver isStop bool userConns utils.ConcurrentMap log *logger.Logger udpRelatedPacketConns utils.ConcurrentMap } func NewSocks() services.Service { return &Socks{ cfg: SocksArgs{}, checker: utils.Checker{}, basicAuth: utils.BasicAuth{}, lockChn: make(chan bool, 1), isStop: false, userConns: utils.NewConcurrentMap(), udpRelatedPacketConns: utils.NewConcurrentMap(), } } func (s *Socks) CheckArgs() (err error) { if *s.cfg.LocalType == "tls" || (*s.cfg.Parent != "" && *s.cfg.ParentType == "tls") { s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) if err != nil { return } if *s.cfg.CaCertFile != "" { s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile) if err != nil { err = fmt.Errorf("read ca file error,ERR:%s", err) return } } } if *s.cfg.Parent != "" { if *s.cfg.ParentType == "" { err = fmt.Errorf("parent type unkown,use -T ") return } host, _, e := net.SplitHostPort(*s.cfg.Parent) if e != nil { err = fmt.Errorf("parent format error : %s", e) return } if *s.cfg.UDPParent == "" { *s.cfg.UDPParent = net.JoinHostPort(host, "33090") } if strings.HasPrefix(*s.cfg.UDPParent, ":") { *s.cfg.UDPParent = net.JoinHostPort(host, strings.TrimLeft(*s.cfg.UDPParent, ":")) } if *s.cfg.ParentType == "ssh" { if *s.cfg.SSHUser == "" { err = fmt.Errorf("ssh user required") return } if *s.cfg.SSHKeyFile == "" && *s.cfg.SSHPassword == "" { err = fmt.Errorf("ssh password or key required") return } if *s.cfg.SSHPassword != "" { s.cfg.SSHAuthMethod = ssh.Password(*s.cfg.SSHPassword) } else { var SSHSigner ssh.Signer s.cfg.SSHKeyBytes, err = ioutil.ReadFile(*s.cfg.SSHKeyFile) if err != nil { err = fmt.Errorf("read key file ERR: %s", err) return } if *s.cfg.SSHKeyFileSalt != "" { SSHSigner, err = ssh.ParsePrivateKeyWithPassphrase(s.cfg.SSHKeyBytes, []byte(*s.cfg.SSHKeyFileSalt)) } else { SSHSigner, err = ssh.ParsePrivateKey(s.cfg.SSHKeyBytes) } if err != nil { err = fmt.Errorf("parse ssh private key fail,ERR: %s", err) return } s.cfg.SSHAuthMethod = ssh.PublicKeys(SSHSigner) } } } return } func (s *Socks) InitService() (err error) { s.InitBasicAuth() if *s.cfg.DNSAddress != "" { (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log) } s.checker = utils.NewChecker(*s.cfg.Timeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct, s.log) if *s.cfg.ParentType == "ssh" { e := s.ConnectSSH() if e != nil { err = fmt.Errorf("init service fail, ERR: %s", e) return } go func() { //循环检查ssh网络连通性 for { if s.isStop { return } conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2) if err == nil { conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) _, err = conn.Write([]byte{0}) conn.SetDeadline(time.Time{}) } if err != nil { if s.sshClient != nil { s.sshClient.Close() } s.log.Printf("ssh offline, retrying...") s.ConnectSSH() } else { conn.Close() } time.Sleep(time.Second * 3) } }() } if *s.cfg.ParentType == "ssh" { s.log.Printf("warn: socks udp not suppored for ssh") } else { s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal, s.log) e := s.udpSC.ListenUDP(s.udpCallback) if e != nil { err = fmt.Errorf("init udp service fail, ERR: %s", e) return } s.log.Printf("udp socks proxy on %s", s.udpSC.UDPListener.LocalAddr()) } return } func (s *Socks) StopService() { defer func() { e := recover() if e != nil { s.log.Printf("stop socks service crashed,%s", e) } else { s.log.Printf("service socks stoped") } }() s.isStop = true s.checker.Stop() if s.sshClient != nil { s.sshClient.Close() } if s.udpSC.UDPListener != nil { s.udpSC.UDPListener.Close() } if s.sc != nil && (*s.sc).Listener != nil { (*(*s.sc).Listener).Close() } for _, c := range s.userConns.Items() { (*c.(*net.Conn)).Close() } for _, c := range s.udpRelatedPacketConns.Items() { (*c.(*net.UDPConn)).Close() } } func (s *Socks) Start(args interface{}, log *logger.Logger) (err error) { s.log = log //start() s.cfg = args.(SocksArgs) if err = s.CheckArgs(); err != nil { return } if err = s.InitService(); err != nil { s.InitService() } if *s.cfg.Parent != "" { s.log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent) } if *s.cfg.UDPParent != "" { s.log.Printf("use socks udp parent %s", *s.cfg.UDPParent) } sc := utils.NewServerChannelHost(*s.cfg.Local, s.log) if *s.cfg.LocalType == "tcp" { err = sc.ListenTCP(s.socksConnCallback) } else if *s.cfg.LocalType == "tls" { err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.socksConnCallback) } else if *s.cfg.LocalType == "kcp" { err = sc.ListenKCP(s.cfg.KCP, s.socksConnCallback, s.log) } if err != nil { return } s.sc = &sc s.log.Printf("%s socks proxy on %s", *s.cfg.LocalType, (*sc.Listener).Addr()) return } func (s *Socks) Clean() { s.StopService() } func (s *Socks) UDPKey() []byte { return s.cfg.KeyBytes[:32] } func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) { rawB := b var err error if *s.cfg.LocalType == "tls" { //decode b rawB, err = goaes.Decrypt(s.UDPKey(), b) if err != nil { s.log.Printf("decrypt udp packet fail from %s", srcAddr.String()) return } } p, err := socks.ParseUDPPacket(rawB) s.log.Printf("udp revecived:%v", len(p.Data())) if err != nil { s.log.Printf("parse udp packet fail, ERR:%s", err) return } //防止死循环 if s.IsDeadLoop((*localAddr).String(), p.Host()) { s.log.Printf("dead loop detected , %s", p.Host()) return } //s.log.Printf("##########udp to -> %s:%s###########", p.Host(), p.Port()) if *s.cfg.Parent != "" { //有上级代理,转发给上级 if *s.cfg.ParentType == "tls" { //encode b rawB, err = goaes.Encrypt(s.UDPKey(), rawB) if err != nil { s.log.Printf("encrypt udp data fail to %s", *s.cfg.Parent) return } } parent := *s.cfg.UDPParent if parent == "" { parent = *s.cfg.Parent } dstAddr, err := net.ResolveUDPAddr("udp", s.Resolve(parent)) if err != nil { s.log.Printf("can't resolve address: %s", err) return } clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) if err != nil { s.log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) return } conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*5))) _, err = conn.Write(rawB) conn.SetDeadline(time.Time{}) s.log.Printf("udp request:%v", len(rawB)) if err != nil { s.log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) conn.Close() return } //s.log.Printf("send udp packet to %s success", dstAddr.String()) buf := make([]byte, 10*1024) conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) length, _, err := conn.ReadFromUDP(buf) conn.SetDeadline(time.Time{}) if err != nil { s.log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) conn.Close() return } respBody := buf[0:length] s.log.Printf("udp response:%v", len(respBody)) //s.log.Printf("revecived udp packet from %s", dstAddr.String()) if *s.cfg.ParentType == "tls" { //decode b respBody, err = goaes.Decrypt(s.UDPKey(), respBody) if err != nil { s.log.Printf("encrypt udp data fail to %s", *s.cfg.Parent) conn.Close() return } } if *s.cfg.LocalType == "tls" { d, err := goaes.Encrypt(s.UDPKey(), respBody) if err != nil { s.log.Printf("encrypt udp data fail from %s", dstAddr.String()) conn.Close() return } s.udpSC.UDPListener.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) s.udpSC.UDPListener.WriteToUDP(d, srcAddr) s.udpSC.UDPListener.SetDeadline(time.Time{}) s.log.Printf("udp reply:%v", len(d)) d = nil } else { s.udpSC.UDPListener.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) s.udpSC.UDPListener.WriteToUDP(respBody, srcAddr) s.udpSC.UDPListener.SetDeadline(time.Time{}) s.log.Printf("udp reply:%v", len(respBody)) } } else { //本地代理 dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(s.Resolve(p.Host()), p.Port())) if err != nil { s.log.Printf("can't resolve address: %s", err) return } clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) if err != nil { s.log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) return } conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*3))) _, err = conn.Write(p.Data()) conn.SetDeadline(time.Time{}) s.log.Printf("udp send:%v", len(p.Data())) if err != nil { s.log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) conn.Close() return } //s.log.Printf("send udp packet to %s success", dstAddr.String()) buf := make([]byte, 10*1024) conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) length, _, err := conn.ReadFromUDP(buf) conn.SetDeadline(time.Time{}) if err != nil { s.log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) conn.Close() return } respBody := buf[0:length] //封装来自真实服务器的数据,返回给访问者 respPacket := p.NewReply(respBody) //s.log.Printf("revecived udp packet from %s", dstAddr.String()) if *s.cfg.LocalType == "tls" { d, err := goaes.Encrypt(s.UDPKey(), respPacket) if err != nil { s.log.Printf("encrypt udp data fail from %s", dstAddr.String()) conn.Close() return } s.udpSC.UDPListener.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) s.udpSC.UDPListener.WriteToUDP(d, srcAddr) s.udpSC.UDPListener.SetDeadline(time.Time{}) d = nil } else { s.udpSC.UDPListener.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) s.udpSC.UDPListener.WriteToUDP(respPacket, srcAddr) s.udpSC.UDPListener.SetDeadline(time.Time{}) } s.log.Printf("udp reply:%v", len(respPacket)) } } func (s *Socks) socksConnCallback(inConn net.Conn) { defer func() { if err := recover(); err != nil { s.log.Printf("socks conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) inConn.Close() } }() if *s.cfg.LocalCompress { inConn = utils.NewCompConn(inConn) } if *s.cfg.LocalKey != "" { inConn = conncrypt.New(inConn, &conncrypt.Config{ Password: *s.cfg.LocalKey, }) } //协商开始 //method select request inConn.SetReadDeadline(time.Now().Add(time.Second * 3)) methodReq, err := socks.NewMethodsRequest(inConn) inConn.SetReadDeadline(time.Time{}) if err != nil { methodReq.Reply(socks.Method_NONE_ACCEPTABLE) utils.CloseConn(&inConn) if err != io.EOF { s.log.Printf("new methods request fail,ERR: %s", err) } return } if !s.IsBasicAuth() { if !methodReq.Select(socks.Method_NO_AUTH) { methodReq.Reply(socks.Method_NONE_ACCEPTABLE) utils.CloseConn(&inConn) s.log.Printf("none method found : Method_NO_AUTH") return } //method select reply err = methodReq.Reply(socks.Method_NO_AUTH) if err != nil { s.log.Printf("reply answer data fail,ERR: %s", err) utils.CloseConn(&inConn) return } // s.log.Printf("% x", methodReq.Bytes()) } else { //auth if !methodReq.Select(socks.Method_USER_PASS) { methodReq.Reply(socks.Method_NONE_ACCEPTABLE) utils.CloseConn(&inConn) s.log.Printf("none method found : Method_USER_PASS") return } //method reply need auth err = methodReq.Reply(socks.Method_USER_PASS) if err != nil { s.log.Printf("reply answer data fail,ERR: %s", err) utils.CloseConn(&inConn) return } //read auth buf := make([]byte, 500) inConn.SetReadDeadline(time.Now().Add(time.Second * 3)) n, err := inConn.Read(buf) inConn.SetReadDeadline(time.Time{}) if err != nil { utils.CloseConn(&inConn) return } r := buf[:n] user := string(r[2 : r[1]+2]) pass := string(r[2+r[1]+1:]) //s.log.Printf("user:%s,pass:%s", user, pass) //auth _addr := strings.Split(inConn.RemoteAddr().String(), ":") if s.basicAuth.CheckUserPass(user, pass, _addr[0], "") { inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) inConn.Write([]byte{0x01, 0x00}) inConn.SetDeadline(time.Time{}) } else { inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) inConn.Write([]byte{0x01, 0x01}) inConn.SetDeadline(time.Time{}) utils.CloseConn(&inConn) return } } //request detail request, err := socks.NewRequest(inConn) if err != nil { s.log.Printf("read request data fail,ERR: %s", err) utils.CloseConn(&inConn) return } //协商结束 switch request.CMD() { case socks.CMD_BIND: //bind 不支持 request.TCPReply(socks.REP_UNKNOWN) utils.CloseConn(&inConn) return case socks.CMD_CONNECT: //tcp s.proxyTCP(&inConn, methodReq, request) case socks.CMD_ASSOCIATE: //udp s.proxyUDP(&inConn, methodReq, request) } } func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) { if *s.cfg.ParentType == "ssh" { utils.CloseConn(inConn) return } inconnRemoteAddr := (*inConn).RemoteAddr().String() localAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} udpListener, err := net.ListenUDP("udp", localAddr) if err != nil { (*inConn).Close() udpListener.Close() s.log.Printf("udp bind fail , %s", err) return } host, _, _ := net.SplitHostPort((*inConn).LocalAddr().String()) _, port, _ := net.SplitHostPort(udpListener.LocalAddr().String()) s.log.Printf("proxy udp on %s , for %s", udpListener.LocalAddr(), inconnRemoteAddr) request.UDPReply(socks.REP_SUCCESS, net.JoinHostPort(host, port)) s.userConns.Set(inconnRemoteAddr, inConn) var ( outUDPConn *net.UDPConn outconn net.Conn outconnLocalAddr string isClosedErr = func(err error) bool { return err != nil && strings.Contains(err.Error(), "use of closed network connection") } ) var clean = func(msg, err string) { raddr := "" if outUDPConn != nil { raddr = outUDPConn.RemoteAddr().String() outUDPConn.Close() } if msg != "" { if raddr != "" { s.log.Printf("%s , %s , %s -> %s", msg, err, inconnRemoteAddr, raddr) } else { s.log.Printf("%s , %s , from : %s", msg, err, inconnRemoteAddr) } } (*inConn).Close() udpListener.Close() s.userConns.Remove(inconnRemoteAddr) if outconn != nil { outconn.Close() } if outconnLocalAddr != "" { s.userConns.Remove(outconnLocalAddr) } } defer clean("", "") go func() { buf := make([]byte, 1) if _, err := (*inConn).Read(buf); err != nil { clean("udp related tcp conn disconnected with read", err.Error()) } }() go func() { for { if _, err := (*inConn).Write([]byte{0x00}); err != nil { clean("udp related tcp conn disconnected with write", err.Error()) return } time.Sleep(time.Second * 5) } }() if *s.cfg.Parent != "" { outconn, err := s.getOutConn(nil, nil, "", false) if err != nil { clean("connnect fail", fmt.Sprintf("%s", err)) return } client := socks.NewClientConn(&outconn, "udp", "", time.Millisecond*time.Duration(*s.cfg.Timeout), nil, nil) if err = client.Handshake(); err != nil { clean("handshake fail", fmt.Sprintf("%s", err)) return } //outconnRemoteAddr := outconn.RemoteAddr().String() outconnLocalAddr = outconn.LocalAddr().String() s.userConns.Set(outconnLocalAddr, &outconn) go func() { buf := make([]byte, 1) if _, err := outconn.Read(buf); err != nil { clean("udp parent tcp conn disconnected", fmt.Sprintf("%s", err)) } }() //forward to parent udp s.log.Printf("parent udp address %s", client.UDPAddr) } else { for { buf := utils.LeakyBuffer.Get() defer utils.LeakyBuffer.Put(buf) n, srcAddr, err := udpListener.ReadFromUDP(buf) if err != nil { s.log.Printf("udp listener read fail, %s", err.Error()) if isClosedErr(err) { return } continue } p := socks.NewPacketUDP() err = p.Parse(buf[:n]) if err != nil { s.log.Printf("udp listener parse packet fail, %s", err.Error()) continue } port, _ := strconv.Atoi(p.Port()) destAddr := &net.UDPAddr{IP: net.ParseIP(p.Host()), Port: port} if v, ok := s.udpRelatedPacketConns.Get(srcAddr.String()); !ok { outUDPConn, err = net.DialUDP("udp", localAddr, destAddr) if err != nil { s.log.Printf("create out udp conn fail , %s , from : %s", err, srcAddr) continue } s.udpRelatedPacketConns.Set(srcAddr.String(), outUDPConn) go func() { defer s.udpRelatedPacketConns.Remove(srcAddr.String()) //bind buf := utils.LeakyBuffer.Get() defer utils.LeakyBuffer.Put(buf) for { n, err := outUDPConn.Read(buf) if err != nil { s.log.Printf("read out udp data fail , %s , from : %s", err, srcAddr) if isClosedErr(err) { return } continue } rp := socks.NewPacketUDP() rp.Build(srcAddr.String(), buf[:n]) d := rp.Bytes() _, err = udpListener.WriteTo(d, srcAddr) if err != nil { s.udpRelatedPacketConns.Remove(srcAddr.String()) s.log.Printf("write out data to local fail , %s , from : %s", err, srcAddr) if isClosedErr(err) { return } continue } else { s.log.Printf("send udp data to local success , len %d, for : %s", len(d), srcAddr) } } }() } else { outUDPConn = v.(*net.UDPConn) } _, err = outUDPConn.Write(p.Data()) if err != nil { if isClosedErr(err) { return } s.log.Printf("send out udp data fail , %s , from : %s", err, srcAddr) continue } else { s.log.Printf("send udp data to remote success , len %d, for : %s", len(p.Data()), srcAddr) } } } } func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) { var outConn net.Conn var err interface{} useProxy := true tryCount := 0 maxTryCount := 5 //防止死循环 if s.IsDeadLoop((*inConn).LocalAddr().String(), request.Host()) { utils.CloseConn(inConn) s.log.Printf("dead loop detected , %s", request.Host()) utils.CloseConn(inConn) return } for { if s.isStop { return } if *s.cfg.Always { outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr(), true) } else { if *s.cfg.Parent != "" { host, _, _ := net.SplitHostPort(request.Addr()) useProxy := false if utils.IsIternalIP(host, *s.cfg.Always) { useProxy = false } else { var isInMap bool useProxy, isInMap, _, _ = s.checker.IsBlocked(request.Addr()) if !isInMap { s.checker.Add(request.Addr(), s.Resolve(request.Addr())) } } if useProxy { outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr(), true) } else { outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout) } } else { outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout) useProxy = false } } tryCount++ if err == nil || tryCount > maxTryCount || *s.cfg.Parent == "" { break } else { s.log.Printf("get out conn fail,%s,retrying...", err) time.Sleep(time.Second * 2) } } if err != nil { s.log.Printf("get out conn fail,%s", err) request.TCPReply(socks.REP_NETWOR_UNREACHABLE) return } s.log.Printf("use proxy %v : %s", useProxy, request.Addr()) request.TCPReply(socks.REP_SUCCESS) inAddr := (*inConn).RemoteAddr().String() //inLocalAddr := (*inConn).LocalAddr().String() s.log.Printf("conn %s - %s connected", inAddr, request.Addr()) utils.IoBind(*inConn, outConn, func(err interface{}) { s.log.Printf("conn %s - %s released", inAddr, request.Addr()) s.userConns.Remove(inAddr) }, s.log) if c, ok := s.userConns.Get(inAddr); ok { (*c.(*net.Conn)).Close() s.userConns.Remove(inAddr) } s.userConns.Set(inAddr, inConn) } func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string, handshake bool) (outConn net.Conn, err interface{}) { switch *s.cfg.ParentType { case "kcp": fallthrough case "tls": fallthrough case "tcp": if *s.cfg.ParentType == "tls" { var _outConn tls.Conn _outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) outConn = net.Conn(&_outConn) } else if *s.cfg.ParentType == "kcp" { outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), s.cfg.KCP) } else { outConn, err = utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout) } if err != nil { err = fmt.Errorf("connect fail,%s", err) return } if *s.cfg.ParentCompress { outConn = utils.NewCompConn(outConn) } if *s.cfg.ParentKey != "" { outConn = conncrypt.New(outConn, &conncrypt.Config{ Password: *s.cfg.ParentKey, }) } if !handshake { return } var buf = make([]byte, 1024) //var n int outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) _, err = outConn.Write(methodBytes) outConn.SetDeadline(time.Time{}) if err != nil { err = fmt.Errorf("write method fail,%s", err) return } outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) _, err = outConn.Read(buf) outConn.SetDeadline(time.Time{}) if err != nil { err = fmt.Errorf("read method reply fail,%s", err) return } //resp := buf[:n] //s.log.Printf("resp:%v", resp) outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) _, err = outConn.Write(reqBytes) outConn.SetDeadline(time.Time{}) if err != nil { err = fmt.Errorf("write req detail fail,%s", err) return } outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) _, err = outConn.Read(buf) outConn.SetDeadline(time.Time{}) if err != nil { err = fmt.Errorf("read req reply fail,%s", err) return } //result := buf[:n] //s.log.Printf("result:%v", result) case "ssh": maxTryCount := 1 tryCount := 0 RETRY: if tryCount >= maxTryCount || s.isStop { return } wait := make(chan bool, 1) go func() { defer func() { if err == nil { err = recover() } wait <- true }() outConn, err = s.sshClient.Dial("tcp", host) }() select { case <-wait: case <-time.After(time.Millisecond * time.Duration(*s.cfg.Timeout) * 2): err = fmt.Errorf("ssh dial %s timeout", host) s.sshClient.Close() } if err != nil { s.log.Printf("connect ssh fail, ERR: %s, retrying...", err) e := s.ConnectSSH() if e == nil { tryCount++ time.Sleep(time.Second * 3) goto RETRY } else { err = e } } } return } func (s *Socks) ConnectSSH() (err error) { select { case s.lockChn <- true: default: err = fmt.Errorf("can not connect at same time") return } config := ssh.ClientConfig{ Timeout: time.Duration(*s.cfg.Timeout) * time.Millisecond, User: *s.cfg.SSHUser, Auth: []ssh.AuthMethod{s.cfg.SSHAuthMethod}, HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, } if s.sshClient != nil { s.sshClient.Close() } s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config) <-s.lockChn return } func (s *Socks) InitBasicAuth() (err error) { if *s.cfg.DNSAddress != "" { s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver, s.log) } else { s.basicAuth = utils.NewBasicAuth(nil, s.log) } if *s.cfg.AuthURL != "" { s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) s.log.Printf("auth from %s", *s.cfg.AuthURL) } if *s.cfg.AuthFile != "" { var n = 0 n, err = s.basicAuth.AddFromFile(*s.cfg.AuthFile) if err != nil { err = fmt.Errorf("auth-file ERR:%s", err) return } s.log.Printf("auth data added from file %d , total:%d", n, s.basicAuth.Total()) } if len(*s.cfg.Auth) > 0 { n := s.basicAuth.Add(*s.cfg.Auth) s.log.Printf("auth data added %d, total:%d", n, s.basicAuth.Total()) } return } func (s *Socks) IsBasicAuth() bool { return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 || *s.cfg.AuthURL != "" } func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool { inIP, inPort, err := net.SplitHostPort(inLocalAddr) if err != nil { return false } outDomain, outPort, err := net.SplitHostPort(host) if err != nil { return false } if inPort == outPort { var outIPs []net.IP if *s.cfg.DNSAddress != "" { outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))} } else { outIPs, err = net.LookupIP(outDomain) } if err == nil { for _, ip := range outIPs { if ip.String() == inIP { return true } } } interfaceIPs, err := utils.GetAllInterfaceAddr() for _, ip := range *s.cfg.LocalIPS { interfaceIPs = append(interfaceIPs, net.ParseIP(ip).To4()) } if err == nil { for _, localIP := range interfaceIPs { for _, outIP := range outIPs { if localIP.Equal(outIP) { return true } } } } } return false } func (s *Socks) Resolve(address string) string { if *s.cfg.DNSAddress == "" { return address } ip, err := s.domainResolver.Resolve(address) if err != nil { s.log.Printf("dns error %s , ERR:%s", address, err) return address } return ip }