add socks dns support
This commit is contained in:
@ -37,7 +37,7 @@ func NewHTTP() Service {
|
||||
func (s *HTTP) CheckArgs() {
|
||||
var err error
|
||||
if *s.cfg.Parent != "" && *s.cfg.ParentType == "" {
|
||||
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>")
|
||||
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
|
||||
}
|
||||
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" {
|
||||
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
||||
@ -86,7 +86,7 @@ func (s *HTTP) InitService() {
|
||||
go func() {
|
||||
//循环检查ssh网络连通性
|
||||
for {
|
||||
conn, err := utils.ConnectHost(s.domainResolver.MustResolve(*s.cfg.Parent), *s.cfg.Timeout*2)
|
||||
conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
|
||||
if err == nil {
|
||||
_, err = conn.Write([]byte{0})
|
||||
}
|
||||
@ -170,9 +170,10 @@ func (s *HTTP) callback(inConn net.Conn) {
|
||||
} else if *s.cfg.Always {
|
||||
useProxy = true
|
||||
} else {
|
||||
s.checker.Add(address)
|
||||
k := s.Resolve(address)
|
||||
s.checker.Add(k)
|
||||
//var n, m uint
|
||||
useProxy, _, _ = s.checker.IsBlocked(req.Host)
|
||||
useProxy, _, _ = s.checker.IsBlocked(k)
|
||||
//log.Printf("blocked ? : %v, %s , fail:%d ,success:%d", useProxy, address, n, m)
|
||||
}
|
||||
}
|
||||
@ -215,7 +216,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
|
||||
}
|
||||
}
|
||||
} else {
|
||||
outConn, err = utils.ConnectHost(s.domainResolver.MustResolve(address), *s.cfg.Timeout)
|
||||
outConn, err = utils.ConnectHost(s.Resolve(address), *s.cfg.Timeout)
|
||||
}
|
||||
tryCount++
|
||||
if err == nil || tryCount > maxTryCount {
|
||||
@ -308,7 +309,7 @@ func (s *HTTP) ConnectSSH() (err error) {
|
||||
if s.sshClient != nil {
|
||||
s.sshClient.Close()
|
||||
}
|
||||
s.sshClient, err = ssh.Dial("tcp", s.domainResolver.MustResolve(*s.cfg.Parent), &config)
|
||||
s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
|
||||
<-s.lockChn
|
||||
return
|
||||
}
|
||||
@ -322,7 +323,7 @@ func (s *HTTP) InitOutConnPool() {
|
||||
*s.cfg.KCPMethod,
|
||||
*s.cfg.KCPKey,
|
||||
s.cfg.CertBytes, s.cfg.KeyBytes,
|
||||
s.domainResolver.MustResolve(*s.cfg.Parent),
|
||||
s.Resolve(*s.cfg.Parent),
|
||||
*s.cfg.Timeout,
|
||||
*s.cfg.PoolSize,
|
||||
*s.cfg.PoolSize*2,
|
||||
@ -330,7 +331,11 @@ func (s *HTTP) InitOutConnPool() {
|
||||
}
|
||||
}
|
||||
func (s *HTTP) InitBasicAuth() (err error) {
|
||||
s.basicAuth = utils.NewBasicAuth()
|
||||
if *s.cfg.DNSAddress != "" {
|
||||
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
|
||||
} else {
|
||||
s.basicAuth = utils.NewBasicAuth(nil)
|
||||
}
|
||||
if *s.cfg.AuthURL != "" {
|
||||
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
|
||||
log.Printf("auth from %s", *s.cfg.AuthURL)
|
||||
@ -364,7 +369,11 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
|
||||
}
|
||||
if inPort == outPort {
|
||||
var outIPs []net.IP
|
||||
outIPs, err = net.LookupIP(outDomain)
|
||||
if *s.cfg.DNSAddress != "" {
|
||||
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
|
||||
} else {
|
||||
outIPs, err = net.LookupIP(outDomain)
|
||||
}
|
||||
if err == nil {
|
||||
for _, ip := range outIPs {
|
||||
if ip.String() == inIP {
|
||||
@ -388,3 +397,9 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (s *HTTP) Resolve(address string) string {
|
||||
if *s.cfg.DNSAddress == "" {
|
||||
return address
|
||||
}
|
||||
return s.domainResolver.MustResolve(address)
|
||||
}
|
||||
|
||||
@ -42,7 +42,7 @@ func (s *Socks) CheckArgs() {
|
||||
}
|
||||
if *s.cfg.Parent != "" {
|
||||
if *s.cfg.ParentType == "" {
|
||||
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>")
|
||||
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
|
||||
}
|
||||
if *s.cfg.ParentType == "tls" {
|
||||
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
||||
@ -90,7 +90,7 @@ func (s *Socks) InitService() {
|
||||
go func() {
|
||||
//循环检查ssh网络连通性
|
||||
for {
|
||||
conn, err := utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout*2)
|
||||
conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
|
||||
if err == nil {
|
||||
_, err = conn.Write([]byte{0})
|
||||
}
|
||||
@ -110,7 +110,6 @@ func (s *Socks) InitService() {
|
||||
if *s.cfg.ParentType == "ssh" {
|
||||
log.Println("warn: socks udp not suppored for ssh")
|
||||
} else {
|
||||
|
||||
s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal)
|
||||
err := s.udpSC.ListenUDP(s.udpCallback)
|
||||
if err != nil {
|
||||
@ -192,7 +191,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
|
||||
if parent == "" {
|
||||
parent = *s.cfg.Parent
|
||||
}
|
||||
dstAddr, err := net.ResolveUDPAddr("udp", parent)
|
||||
dstAddr, err := net.ResolveUDPAddr("udp", s.Resolve(parent))
|
||||
if err != nil {
|
||||
log.Printf("can't resolve address: %s", err)
|
||||
return
|
||||
@ -248,7 +247,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
|
||||
|
||||
} else {
|
||||
//本地代理
|
||||
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port()))
|
||||
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(s.Resolve(p.Host()), p.Port()))
|
||||
if err != nil {
|
||||
log.Printf("can't resolve address: %s", err)
|
||||
return
|
||||
@ -425,16 +424,17 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
|
||||
if utils.IsIternalIP(host) {
|
||||
useProxy = false
|
||||
} else {
|
||||
s.checker.Add(request.Addr())
|
||||
useProxy, _, _ = s.checker.IsBlocked(request.Addr())
|
||||
k := s.Resolve(request.Addr())
|
||||
s.checker.Add(k)
|
||||
useProxy, _, _ = s.checker.IsBlocked(k)
|
||||
}
|
||||
if useProxy {
|
||||
outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr())
|
||||
} else {
|
||||
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
|
||||
outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
|
||||
}
|
||||
} else {
|
||||
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
|
||||
outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
|
||||
useProxy = false
|
||||
}
|
||||
}
|
||||
@ -471,12 +471,12 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n
|
||||
case "tcp":
|
||||
if *s.cfg.ParentType == "tls" {
|
||||
var _outConn tls.Conn
|
||||
_outConn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
|
||||
_outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
|
||||
outConn = net.Conn(&_outConn)
|
||||
} else if *s.cfg.ParentType == "kcp" {
|
||||
outConn, err = utils.ConnectKCPHost(*s.cfg.Parent, *s.cfg.KCPMethod, *s.cfg.KCPKey)
|
||||
outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), *s.cfg.KCPMethod, *s.cfg.KCPKey)
|
||||
} else {
|
||||
outConn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout)
|
||||
outConn, err = utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout)
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("connect fail,%s", err)
|
||||
@ -566,12 +566,16 @@ func (s *Socks) ConnectSSH() (err error) {
|
||||
if s.sshClient != nil {
|
||||
s.sshClient.Close()
|
||||
}
|
||||
s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config)
|
||||
s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
|
||||
<-s.lockChn
|
||||
return
|
||||
}
|
||||
func (s *Socks) InitBasicAuth() (err error) {
|
||||
s.basicAuth = utils.NewBasicAuth()
|
||||
if *s.cfg.DNSAddress != "" {
|
||||
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
|
||||
} else {
|
||||
s.basicAuth = utils.NewBasicAuth(nil)
|
||||
}
|
||||
if *s.cfg.AuthURL != "" {
|
||||
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
|
||||
log.Printf("auth from %s", *s.cfg.AuthURL)
|
||||
@ -605,7 +609,11 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
|
||||
}
|
||||
if inPort == outPort {
|
||||
var outIPs []net.IP
|
||||
outIPs, err = net.LookupIP(outDomain)
|
||||
if *s.cfg.DNSAddress != "" {
|
||||
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
|
||||
} else {
|
||||
outIPs, err = net.LookupIP(outDomain)
|
||||
}
|
||||
if err == nil {
|
||||
for _, ip := range outIPs {
|
||||
if ip.String() == inIP {
|
||||
@ -629,3 +637,9 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (s *Socks) Resolve(address string) string {
|
||||
if *s.cfg.DNSAddress == "" {
|
||||
return address
|
||||
}
|
||||
return s.domainResolver.MustResolve(address)
|
||||
}
|
||||
|
||||
@ -451,7 +451,7 @@ func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, er
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(host) == 1 {
|
||||
if len(host) == 1 && host[0] != "" {
|
||||
req.Host = host[0]
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
|
||||
@ -173,11 +173,13 @@ type BasicAuth struct {
|
||||
authOkCode int
|
||||
authTimeout int
|
||||
authRetry int
|
||||
dns *DomainResolver
|
||||
}
|
||||
|
||||
func NewBasicAuth() BasicAuth {
|
||||
func NewBasicAuth(dns *DomainResolver) BasicAuth {
|
||||
return BasicAuth{
|
||||
data: NewConcurrentMap(),
|
||||
dns: dns,
|
||||
}
|
||||
}
|
||||
func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) {
|
||||
@ -241,18 +243,27 @@ func (ba *BasicAuth) checkFromURL(userpass, ip, target string) (err error) {
|
||||
if len(u) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
URL := ba.authURL
|
||||
if strings.Contains(URL, "?") {
|
||||
URL += "&"
|
||||
} else {
|
||||
URL += "?"
|
||||
}
|
||||
URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, target)
|
||||
URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, url.QueryEscape(target))
|
||||
getURL := URL
|
||||
var domain string
|
||||
if ba.dns != nil {
|
||||
_url, _ := url.Parse(ba.authURL)
|
||||
domain = _url.Host
|
||||
domainIP := ba.dns.MustResolve(domain)
|
||||
getURL = strings.Replace(URL, domain, domainIP, 1)
|
||||
}
|
||||
var code int
|
||||
var tryCount = 0
|
||||
var body []byte
|
||||
for tryCount <= ba.authRetry {
|
||||
body, code, err = HttpGet(URL, ba.authTimeout)
|
||||
body, code, err = HttpGet(getURL, ba.authTimeout, domain)
|
||||
if err == nil && code == ba.authOkCode {
|
||||
break
|
||||
} else if err != nil {
|
||||
@ -855,6 +866,10 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if net.ParseIP(domain) != nil {
|
||||
ip = domain
|
||||
return
|
||||
}
|
||||
item, ok := a.data.Get(domain)
|
||||
if ok {
|
||||
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
|
||||
@ -868,10 +883,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
||||
}
|
||||
a.data.Set(domain, item)
|
||||
}
|
||||
|
||||
if net.ParseIP(domain) != nil {
|
||||
return domain, nil
|
||||
}
|
||||
c := new(dns.Client)
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||
@ -880,7 +891,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user