From 7599e2c79377feb480f9c728ee14c54d4305d44e Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Tue, 27 Feb 2018 18:33:22 +0800 Subject: [PATCH] add socks dns support --- services/http.go | 33 ++++++++++++++++++++++++--------- services/socks.go | 44 +++++++++++++++++++++++++++++--------------- utils/functions.go | 2 +- utils/structs.go | 26 ++++++++++++++++++-------- 4 files changed, 72 insertions(+), 33 deletions(-) diff --git a/services/http.go b/services/http.go index cd3f0cc..1eeeb9d 100644 --- a/services/http.go +++ b/services/http.go @@ -37,7 +37,7 @@ func NewHTTP() Service { func (s *HTTP) CheckArgs() { var err error if *s.cfg.Parent != "" && *s.cfg.ParentType == "" { - log.Fatalf("parent type unkown,use -T ") + log.Fatalf("parent type unkown,use -T ") } if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" { s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) @@ -86,7 +86,7 @@ func (s *HTTP) InitService() { go func() { //循环检查ssh网络连通性 for { - conn, err := utils.ConnectHost(s.domainResolver.MustResolve(*s.cfg.Parent), *s.cfg.Timeout*2) + conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2) if err == nil { _, err = conn.Write([]byte{0}) } @@ -170,9 +170,10 @@ func (s *HTTP) callback(inConn net.Conn) { } else if *s.cfg.Always { useProxy = true } else { - s.checker.Add(address) + k := s.Resolve(address) + s.checker.Add(k) //var n, m uint - useProxy, _, _ = s.checker.IsBlocked(req.Host) + useProxy, _, _ = s.checker.IsBlocked(k) //log.Printf("blocked ? : %v, %s , fail:%d ,success:%d", useProxy, address, n, m) } } @@ -215,7 +216,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut } } } else { - outConn, err = utils.ConnectHost(s.domainResolver.MustResolve(address), *s.cfg.Timeout) + outConn, err = utils.ConnectHost(s.Resolve(address), *s.cfg.Timeout) } tryCount++ if err == nil || tryCount > maxTryCount { @@ -308,7 +309,7 @@ func (s *HTTP) ConnectSSH() (err error) { if s.sshClient != nil { s.sshClient.Close() } - s.sshClient, err = ssh.Dial("tcp", s.domainResolver.MustResolve(*s.cfg.Parent), &config) + s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config) <-s.lockChn return } @@ -322,7 +323,7 @@ func (s *HTTP) InitOutConnPool() { *s.cfg.KCPMethod, *s.cfg.KCPKey, s.cfg.CertBytes, s.cfg.KeyBytes, - s.domainResolver.MustResolve(*s.cfg.Parent), + s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, *s.cfg.PoolSize, *s.cfg.PoolSize*2, @@ -330,7 +331,11 @@ func (s *HTTP) InitOutConnPool() { } } func (s *HTTP) InitBasicAuth() (err error) { - s.basicAuth = utils.NewBasicAuth() + if *s.cfg.DNSAddress != "" { + s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver) + } else { + s.basicAuth = utils.NewBasicAuth(nil) + } if *s.cfg.AuthURL != "" { s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) log.Printf("auth from %s", *s.cfg.AuthURL) @@ -364,7 +369,11 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool { } if inPort == outPort { var outIPs []net.IP - outIPs, err = net.LookupIP(outDomain) + if *s.cfg.DNSAddress != "" { + outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))} + } else { + outIPs, err = net.LookupIP(outDomain) + } if err == nil { for _, ip := range outIPs { if ip.String() == inIP { @@ -388,3 +397,9 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool { } return false } +func (s *HTTP) Resolve(address string) string { + if *s.cfg.DNSAddress == "" { + return address + } + return s.domainResolver.MustResolve(address) +} diff --git a/services/socks.go b/services/socks.go index e1d13ed..377ce74 100644 --- a/services/socks.go +++ b/services/socks.go @@ -42,7 +42,7 @@ func (s *Socks) CheckArgs() { } if *s.cfg.Parent != "" { if *s.cfg.ParentType == "" { - log.Fatalf("parent type unkown,use -T ") + log.Fatalf("parent type unkown,use -T ") } if *s.cfg.ParentType == "tls" { s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) @@ -90,7 +90,7 @@ func (s *Socks) InitService() { go func() { //循环检查ssh网络连通性 for { - conn, err := utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout*2) + conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2) if err == nil { _, err = conn.Write([]byte{0}) } @@ -110,7 +110,6 @@ func (s *Socks) InitService() { if *s.cfg.ParentType == "ssh" { log.Println("warn: socks udp not suppored for ssh") } else { - s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal) err := s.udpSC.ListenUDP(s.udpCallback) if err != nil { @@ -192,7 +191,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) { if parent == "" { parent = *s.cfg.Parent } - dstAddr, err := net.ResolveUDPAddr("udp", parent) + dstAddr, err := net.ResolveUDPAddr("udp", s.Resolve(parent)) if err != nil { log.Printf("can't resolve address: %s", err) return @@ -248,7 +247,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) { } else { //本地代理 - dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port())) + dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(s.Resolve(p.Host()), p.Port())) if err != nil { log.Printf("can't resolve address: %s", err) return @@ -425,16 +424,17 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque if utils.IsIternalIP(host) { useProxy = false } else { - s.checker.Add(request.Addr()) - useProxy, _, _ = s.checker.IsBlocked(request.Addr()) + k := s.Resolve(request.Addr()) + s.checker.Add(k) + useProxy, _, _ = s.checker.IsBlocked(k) } if useProxy { outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr()) } else { - outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) + outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout) } } else { - outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) + outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout) useProxy = false } } @@ -471,12 +471,12 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n case "tcp": if *s.cfg.ParentType == "tls" { var _outConn tls.Conn - _outConn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + _outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) outConn = net.Conn(&_outConn) } else if *s.cfg.ParentType == "kcp" { - outConn, err = utils.ConnectKCPHost(*s.cfg.Parent, *s.cfg.KCPMethod, *s.cfg.KCPKey) + outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), *s.cfg.KCPMethod, *s.cfg.KCPKey) } else { - outConn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout) + outConn, err = utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout) } if err != nil { err = fmt.Errorf("connect fail,%s", err) @@ -566,12 +566,16 @@ func (s *Socks) ConnectSSH() (err error) { if s.sshClient != nil { s.sshClient.Close() } - s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config) + s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config) <-s.lockChn return } func (s *Socks) InitBasicAuth() (err error) { - s.basicAuth = utils.NewBasicAuth() + if *s.cfg.DNSAddress != "" { + s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver) + } else { + s.basicAuth = utils.NewBasicAuth(nil) + } if *s.cfg.AuthURL != "" { s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) log.Printf("auth from %s", *s.cfg.AuthURL) @@ -605,7 +609,11 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool { } if inPort == outPort { var outIPs []net.IP - outIPs, err = net.LookupIP(outDomain) + if *s.cfg.DNSAddress != "" { + outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))} + } else { + outIPs, err = net.LookupIP(outDomain) + } if err == nil { for _, ip := range outIPs { if ip.String() == inIP { @@ -629,3 +637,9 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool { } return false } +func (s *Socks) Resolve(address string) string { + if *s.cfg.DNSAddress == "" { + return address + } + return s.domainResolver.MustResolve(address) +} diff --git a/utils/functions.go b/utils/functions.go index 80a16ab..873878c 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -451,7 +451,7 @@ func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, er if err != nil { return } - if len(host) == 1 { + if len(host) == 1 && host[0] != "" { req.Host = host[0] } resp, err := client.Do(req) diff --git a/utils/structs.go b/utils/structs.go index a2c5baa..6a40a20 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -173,11 +173,13 @@ type BasicAuth struct { authOkCode int authTimeout int authRetry int + dns *DomainResolver } -func NewBasicAuth() BasicAuth { +func NewBasicAuth(dns *DomainResolver) BasicAuth { return BasicAuth{ data: NewConcurrentMap(), + dns: dns, } } func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) { @@ -241,18 +243,27 @@ func (ba *BasicAuth) checkFromURL(userpass, ip, target string) (err error) { if len(u) != 2 { return } + URL := ba.authURL if strings.Contains(URL, "?") { URL += "&" } else { URL += "?" } - URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, target) + URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, url.QueryEscape(target)) + getURL := URL + var domain string + if ba.dns != nil { + _url, _ := url.Parse(ba.authURL) + domain = _url.Host + domainIP := ba.dns.MustResolve(domain) + getURL = strings.Replace(URL, domain, domainIP, 1) + } var code int var tryCount = 0 var body []byte for tryCount <= ba.authRetry { - body, code, err = HttpGet(URL, ba.authTimeout) + body, code, err = HttpGet(getURL, ba.authTimeout, domain) if err == nil && code == ba.authOkCode { break } else if err != nil { @@ -855,6 +866,10 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) { return } } + if net.ParseIP(domain) != nil { + ip = domain + return + } item, ok := a.data.Get(domain) if ok { if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() { @@ -868,10 +883,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) { } a.data.Set(domain, item) } - - if net.ParseIP(domain) != nil { - return domain, nil - } c := new(dns.Client) m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) @@ -880,7 +891,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) { if r == nil { return } - if r.Rcode != dns.RcodeSuccess { err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress) return