add socks dns support

This commit is contained in:
arraykeys@gmail.com
2018-02-27 18:33:22 +08:00
parent 3726f5b9c3
commit 7599e2c793
4 changed files with 72 additions and 33 deletions

View File

@ -37,7 +37,7 @@ func NewHTTP() Service {
func (s *HTTP) CheckArgs() {
var err error
if *s.cfg.Parent != "" && *s.cfg.ParentType == "" {
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>")
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
}
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
@ -86,7 +86,7 @@ func (s *HTTP) InitService() {
go func() {
//循环检查ssh网络连通性
for {
conn, err := utils.ConnectHost(s.domainResolver.MustResolve(*s.cfg.Parent), *s.cfg.Timeout*2)
conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
if err == nil {
_, err = conn.Write([]byte{0})
}
@ -170,9 +170,10 @@ func (s *HTTP) callback(inConn net.Conn) {
} else if *s.cfg.Always {
useProxy = true
} else {
s.checker.Add(address)
k := s.Resolve(address)
s.checker.Add(k)
//var n, m uint
useProxy, _, _ = s.checker.IsBlocked(req.Host)
useProxy, _, _ = s.checker.IsBlocked(k)
//log.Printf("blocked ? : %v, %s , fail:%d ,success:%d", useProxy, address, n, m)
}
}
@ -215,7 +216,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
}
}
} else {
outConn, err = utils.ConnectHost(s.domainResolver.MustResolve(address), *s.cfg.Timeout)
outConn, err = utils.ConnectHost(s.Resolve(address), *s.cfg.Timeout)
}
tryCount++
if err == nil || tryCount > maxTryCount {
@ -308,7 +309,7 @@ func (s *HTTP) ConnectSSH() (err error) {
if s.sshClient != nil {
s.sshClient.Close()
}
s.sshClient, err = ssh.Dial("tcp", s.domainResolver.MustResolve(*s.cfg.Parent), &config)
s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
<-s.lockChn
return
}
@ -322,7 +323,7 @@ func (s *HTTP) InitOutConnPool() {
*s.cfg.KCPMethod,
*s.cfg.KCPKey,
s.cfg.CertBytes, s.cfg.KeyBytes,
s.domainResolver.MustResolve(*s.cfg.Parent),
s.Resolve(*s.cfg.Parent),
*s.cfg.Timeout,
*s.cfg.PoolSize,
*s.cfg.PoolSize*2,
@ -330,7 +331,11 @@ func (s *HTTP) InitOutConnPool() {
}
}
func (s *HTTP) InitBasicAuth() (err error) {
s.basicAuth = utils.NewBasicAuth()
if *s.cfg.DNSAddress != "" {
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
} else {
s.basicAuth = utils.NewBasicAuth(nil)
}
if *s.cfg.AuthURL != "" {
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
log.Printf("auth from %s", *s.cfg.AuthURL)
@ -364,7 +369,11 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
}
if inPort == outPort {
var outIPs []net.IP
if *s.cfg.DNSAddress != "" {
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
} else {
outIPs, err = net.LookupIP(outDomain)
}
if err == nil {
for _, ip := range outIPs {
if ip.String() == inIP {
@ -388,3 +397,9 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
}
return false
}
func (s *HTTP) Resolve(address string) string {
if *s.cfg.DNSAddress == "" {
return address
}
return s.domainResolver.MustResolve(address)
}

View File

@ -42,7 +42,7 @@ func (s *Socks) CheckArgs() {
}
if *s.cfg.Parent != "" {
if *s.cfg.ParentType == "" {
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>")
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
}
if *s.cfg.ParentType == "tls" {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
@ -90,7 +90,7 @@ func (s *Socks) InitService() {
go func() {
//循环检查ssh网络连通性
for {
conn, err := utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout*2)
conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
if err == nil {
_, err = conn.Write([]byte{0})
}
@ -110,7 +110,6 @@ func (s *Socks) InitService() {
if *s.cfg.ParentType == "ssh" {
log.Println("warn: socks udp not suppored for ssh")
} else {
s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal)
err := s.udpSC.ListenUDP(s.udpCallback)
if err != nil {
@ -192,7 +191,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
if parent == "" {
parent = *s.cfg.Parent
}
dstAddr, err := net.ResolveUDPAddr("udp", parent)
dstAddr, err := net.ResolveUDPAddr("udp", s.Resolve(parent))
if err != nil {
log.Printf("can't resolve address: %s", err)
return
@ -248,7 +247,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
} else {
//本地代理
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port()))
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(s.Resolve(p.Host()), p.Port()))
if err != nil {
log.Printf("can't resolve address: %s", err)
return
@ -425,16 +424,17 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
if utils.IsIternalIP(host) {
useProxy = false
} else {
s.checker.Add(request.Addr())
useProxy, _, _ = s.checker.IsBlocked(request.Addr())
k := s.Resolve(request.Addr())
s.checker.Add(k)
useProxy, _, _ = s.checker.IsBlocked(k)
}
if useProxy {
outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr())
} else {
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
}
} else {
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
useProxy = false
}
}
@ -471,12 +471,12 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n
case "tcp":
if *s.cfg.ParentType == "tls" {
var _outConn tls.Conn
_outConn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
_outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
outConn = net.Conn(&_outConn)
} else if *s.cfg.ParentType == "kcp" {
outConn, err = utils.ConnectKCPHost(*s.cfg.Parent, *s.cfg.KCPMethod, *s.cfg.KCPKey)
outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), *s.cfg.KCPMethod, *s.cfg.KCPKey)
} else {
outConn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout)
outConn, err = utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout)
}
if err != nil {
err = fmt.Errorf("connect fail,%s", err)
@ -566,12 +566,16 @@ func (s *Socks) ConnectSSH() (err error) {
if s.sshClient != nil {
s.sshClient.Close()
}
s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config)
s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
<-s.lockChn
return
}
func (s *Socks) InitBasicAuth() (err error) {
s.basicAuth = utils.NewBasicAuth()
if *s.cfg.DNSAddress != "" {
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
} else {
s.basicAuth = utils.NewBasicAuth(nil)
}
if *s.cfg.AuthURL != "" {
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
log.Printf("auth from %s", *s.cfg.AuthURL)
@ -605,7 +609,11 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
}
if inPort == outPort {
var outIPs []net.IP
if *s.cfg.DNSAddress != "" {
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
} else {
outIPs, err = net.LookupIP(outDomain)
}
if err == nil {
for _, ip := range outIPs {
if ip.String() == inIP {
@ -629,3 +637,9 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
}
return false
}
func (s *Socks) Resolve(address string) string {
if *s.cfg.DNSAddress == "" {
return address
}
return s.domainResolver.MustResolve(address)
}

View File

@ -451,7 +451,7 @@ func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, er
if err != nil {
return
}
if len(host) == 1 {
if len(host) == 1 && host[0] != "" {
req.Host = host[0]
}
resp, err := client.Do(req)

View File

@ -173,11 +173,13 @@ type BasicAuth struct {
authOkCode int
authTimeout int
authRetry int
dns *DomainResolver
}
func NewBasicAuth() BasicAuth {
func NewBasicAuth(dns *DomainResolver) BasicAuth {
return BasicAuth{
data: NewConcurrentMap(),
dns: dns,
}
}
func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) {
@ -241,18 +243,27 @@ func (ba *BasicAuth) checkFromURL(userpass, ip, target string) (err error) {
if len(u) != 2 {
return
}
URL := ba.authURL
if strings.Contains(URL, "?") {
URL += "&"
} else {
URL += "?"
}
URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, target)
URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, url.QueryEscape(target))
getURL := URL
var domain string
if ba.dns != nil {
_url, _ := url.Parse(ba.authURL)
domain = _url.Host
domainIP := ba.dns.MustResolve(domain)
getURL = strings.Replace(URL, domain, domainIP, 1)
}
var code int
var tryCount = 0
var body []byte
for tryCount <= ba.authRetry {
body, code, err = HttpGet(URL, ba.authTimeout)
body, code, err = HttpGet(getURL, ba.authTimeout, domain)
if err == nil && code == ba.authOkCode {
break
} else if err != nil {
@ -855,6 +866,10 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
return
}
}
if net.ParseIP(domain) != nil {
ip = domain
return
}
item, ok := a.data.Get(domain)
if ok {
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
@ -868,10 +883,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
}
a.data.Set(domain, item)
}
if net.ParseIP(domain) != nil {
return domain, nil
}
c := new(dns.Client)
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
@ -880,7 +891,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
if r == nil {
return
}
if r.Rcode != dns.RcodeSuccess {
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
return