add socks dns support

This commit is contained in:
arraykeys@gmail.com
2018-02-27 18:33:22 +08:00
parent 3726f5b9c3
commit 7599e2c793
4 changed files with 72 additions and 33 deletions

View File

@ -37,7 +37,7 @@ func NewHTTP() Service {
func (s *HTTP) CheckArgs() { func (s *HTTP) CheckArgs() {
var err error var err error
if *s.cfg.Parent != "" && *s.cfg.ParentType == "" { if *s.cfg.Parent != "" && *s.cfg.ParentType == "" {
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>") log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
} }
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" { if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
@ -86,7 +86,7 @@ func (s *HTTP) InitService() {
go func() { go func() {
//循环检查ssh网络连通性 //循环检查ssh网络连通性
for { for {
conn, err := utils.ConnectHost(s.domainResolver.MustResolve(*s.cfg.Parent), *s.cfg.Timeout*2) conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
if err == nil { if err == nil {
_, err = conn.Write([]byte{0}) _, err = conn.Write([]byte{0})
} }
@ -170,9 +170,10 @@ func (s *HTTP) callback(inConn net.Conn) {
} else if *s.cfg.Always { } else if *s.cfg.Always {
useProxy = true useProxy = true
} else { } else {
s.checker.Add(address) k := s.Resolve(address)
s.checker.Add(k)
//var n, m uint //var n, m uint
useProxy, _, _ = s.checker.IsBlocked(req.Host) useProxy, _, _ = s.checker.IsBlocked(k)
//log.Printf("blocked ? : %v, %s , fail:%d ,success:%d", useProxy, address, n, m) //log.Printf("blocked ? : %v, %s , fail:%d ,success:%d", useProxy, address, n, m)
} }
} }
@ -215,7 +216,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
} }
} }
} else { } else {
outConn, err = utils.ConnectHost(s.domainResolver.MustResolve(address), *s.cfg.Timeout) outConn, err = utils.ConnectHost(s.Resolve(address), *s.cfg.Timeout)
} }
tryCount++ tryCount++
if err == nil || tryCount > maxTryCount { if err == nil || tryCount > maxTryCount {
@ -308,7 +309,7 @@ func (s *HTTP) ConnectSSH() (err error) {
if s.sshClient != nil { if s.sshClient != nil {
s.sshClient.Close() s.sshClient.Close()
} }
s.sshClient, err = ssh.Dial("tcp", s.domainResolver.MustResolve(*s.cfg.Parent), &config) s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
<-s.lockChn <-s.lockChn
return return
} }
@ -322,7 +323,7 @@ func (s *HTTP) InitOutConnPool() {
*s.cfg.KCPMethod, *s.cfg.KCPMethod,
*s.cfg.KCPKey, *s.cfg.KCPKey,
s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CertBytes, s.cfg.KeyBytes,
s.domainResolver.MustResolve(*s.cfg.Parent), s.Resolve(*s.cfg.Parent),
*s.cfg.Timeout, *s.cfg.Timeout,
*s.cfg.PoolSize, *s.cfg.PoolSize,
*s.cfg.PoolSize*2, *s.cfg.PoolSize*2,
@ -330,7 +331,11 @@ func (s *HTTP) InitOutConnPool() {
} }
} }
func (s *HTTP) InitBasicAuth() (err error) { func (s *HTTP) InitBasicAuth() (err error) {
s.basicAuth = utils.NewBasicAuth() if *s.cfg.DNSAddress != "" {
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
} else {
s.basicAuth = utils.NewBasicAuth(nil)
}
if *s.cfg.AuthURL != "" { if *s.cfg.AuthURL != "" {
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
log.Printf("auth from %s", *s.cfg.AuthURL) log.Printf("auth from %s", *s.cfg.AuthURL)
@ -364,7 +369,11 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
} }
if inPort == outPort { if inPort == outPort {
var outIPs []net.IP var outIPs []net.IP
outIPs, err = net.LookupIP(outDomain) if *s.cfg.DNSAddress != "" {
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
} else {
outIPs, err = net.LookupIP(outDomain)
}
if err == nil { if err == nil {
for _, ip := range outIPs { for _, ip := range outIPs {
if ip.String() == inIP { if ip.String() == inIP {
@ -388,3 +397,9 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
} }
return false return false
} }
func (s *HTTP) Resolve(address string) string {
if *s.cfg.DNSAddress == "" {
return address
}
return s.domainResolver.MustResolve(address)
}

View File

@ -42,7 +42,7 @@ func (s *Socks) CheckArgs() {
} }
if *s.cfg.Parent != "" { if *s.cfg.Parent != "" {
if *s.cfg.ParentType == "" { if *s.cfg.ParentType == "" {
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>") log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
} }
if *s.cfg.ParentType == "tls" { if *s.cfg.ParentType == "tls" {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
@ -90,7 +90,7 @@ func (s *Socks) InitService() {
go func() { go func() {
//循环检查ssh网络连通性 //循环检查ssh网络连通性
for { for {
conn, err := utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout*2) conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
if err == nil { if err == nil {
_, err = conn.Write([]byte{0}) _, err = conn.Write([]byte{0})
} }
@ -110,7 +110,6 @@ func (s *Socks) InitService() {
if *s.cfg.ParentType == "ssh" { if *s.cfg.ParentType == "ssh" {
log.Println("warn: socks udp not suppored for ssh") log.Println("warn: socks udp not suppored for ssh")
} else { } else {
s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal) s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal)
err := s.udpSC.ListenUDP(s.udpCallback) err := s.udpSC.ListenUDP(s.udpCallback)
if err != nil { if err != nil {
@ -192,7 +191,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
if parent == "" { if parent == "" {
parent = *s.cfg.Parent parent = *s.cfg.Parent
} }
dstAddr, err := net.ResolveUDPAddr("udp", parent) dstAddr, err := net.ResolveUDPAddr("udp", s.Resolve(parent))
if err != nil { if err != nil {
log.Printf("can't resolve address: %s", err) log.Printf("can't resolve address: %s", err)
return return
@ -248,7 +247,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
} else { } else {
//本地代理 //本地代理
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port())) dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(s.Resolve(p.Host()), p.Port()))
if err != nil { if err != nil {
log.Printf("can't resolve address: %s", err) log.Printf("can't resolve address: %s", err)
return return
@ -425,16 +424,17 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
if utils.IsIternalIP(host) { if utils.IsIternalIP(host) {
useProxy = false useProxy = false
} else { } else {
s.checker.Add(request.Addr()) k := s.Resolve(request.Addr())
useProxy, _, _ = s.checker.IsBlocked(request.Addr()) s.checker.Add(k)
useProxy, _, _ = s.checker.IsBlocked(k)
} }
if useProxy { if useProxy {
outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr()) outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr())
} else { } else {
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
} }
} else { } else {
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
useProxy = false useProxy = false
} }
} }
@ -471,12 +471,12 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n
case "tcp": case "tcp":
if *s.cfg.ParentType == "tls" { if *s.cfg.ParentType == "tls" {
var _outConn tls.Conn var _outConn tls.Conn
_outConn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) _outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
outConn = net.Conn(&_outConn) outConn = net.Conn(&_outConn)
} else if *s.cfg.ParentType == "kcp" { } else if *s.cfg.ParentType == "kcp" {
outConn, err = utils.ConnectKCPHost(*s.cfg.Parent, *s.cfg.KCPMethod, *s.cfg.KCPKey) outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), *s.cfg.KCPMethod, *s.cfg.KCPKey)
} else { } else {
outConn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout) outConn, err = utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout)
} }
if err != nil { if err != nil {
err = fmt.Errorf("connect fail,%s", err) err = fmt.Errorf("connect fail,%s", err)
@ -566,12 +566,16 @@ func (s *Socks) ConnectSSH() (err error) {
if s.sshClient != nil { if s.sshClient != nil {
s.sshClient.Close() s.sshClient.Close()
} }
s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config) s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
<-s.lockChn <-s.lockChn
return return
} }
func (s *Socks) InitBasicAuth() (err error) { func (s *Socks) InitBasicAuth() (err error) {
s.basicAuth = utils.NewBasicAuth() if *s.cfg.DNSAddress != "" {
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
} else {
s.basicAuth = utils.NewBasicAuth(nil)
}
if *s.cfg.AuthURL != "" { if *s.cfg.AuthURL != "" {
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
log.Printf("auth from %s", *s.cfg.AuthURL) log.Printf("auth from %s", *s.cfg.AuthURL)
@ -605,7 +609,11 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
} }
if inPort == outPort { if inPort == outPort {
var outIPs []net.IP var outIPs []net.IP
outIPs, err = net.LookupIP(outDomain) if *s.cfg.DNSAddress != "" {
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
} else {
outIPs, err = net.LookupIP(outDomain)
}
if err == nil { if err == nil {
for _, ip := range outIPs { for _, ip := range outIPs {
if ip.String() == inIP { if ip.String() == inIP {
@ -629,3 +637,9 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
} }
return false return false
} }
func (s *Socks) Resolve(address string) string {
if *s.cfg.DNSAddress == "" {
return address
}
return s.domainResolver.MustResolve(address)
}

View File

@ -451,7 +451,7 @@ func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, er
if err != nil { if err != nil {
return return
} }
if len(host) == 1 { if len(host) == 1 && host[0] != "" {
req.Host = host[0] req.Host = host[0]
} }
resp, err := client.Do(req) resp, err := client.Do(req)

View File

@ -173,11 +173,13 @@ type BasicAuth struct {
authOkCode int authOkCode int
authTimeout int authTimeout int
authRetry int authRetry int
dns *DomainResolver
} }
func NewBasicAuth() BasicAuth { func NewBasicAuth(dns *DomainResolver) BasicAuth {
return BasicAuth{ return BasicAuth{
data: NewConcurrentMap(), data: NewConcurrentMap(),
dns: dns,
} }
} }
func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) { func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) {
@ -241,18 +243,27 @@ func (ba *BasicAuth) checkFromURL(userpass, ip, target string) (err error) {
if len(u) != 2 { if len(u) != 2 {
return return
} }
URL := ba.authURL URL := ba.authURL
if strings.Contains(URL, "?") { if strings.Contains(URL, "?") {
URL += "&" URL += "&"
} else { } else {
URL += "?" URL += "?"
} }
URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, target) URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, url.QueryEscape(target))
getURL := URL
var domain string
if ba.dns != nil {
_url, _ := url.Parse(ba.authURL)
domain = _url.Host
domainIP := ba.dns.MustResolve(domain)
getURL = strings.Replace(URL, domain, domainIP, 1)
}
var code int var code int
var tryCount = 0 var tryCount = 0
var body []byte var body []byte
for tryCount <= ba.authRetry { for tryCount <= ba.authRetry {
body, code, err = HttpGet(URL, ba.authTimeout) body, code, err = HttpGet(getURL, ba.authTimeout, domain)
if err == nil && code == ba.authOkCode { if err == nil && code == ba.authOkCode {
break break
} else if err != nil { } else if err != nil {
@ -855,6 +866,10 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
return return
} }
} }
if net.ParseIP(domain) != nil {
ip = domain
return
}
item, ok := a.data.Get(domain) item, ok := a.data.Get(domain)
if ok { if ok {
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() { if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
@ -868,10 +883,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
} }
a.data.Set(domain, item) a.data.Set(domain, item)
} }
if net.ParseIP(domain) != nil {
return domain, nil
}
c := new(dns.Client) c := new(dns.Client)
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
@ -880,7 +891,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
if r == nil { if r == nil {
return return
} }
if r.Rcode != dns.RcodeSuccess { if r.Rcode != dns.RcodeSuccess {
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress) err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
return return