add socks dns support
This commit is contained in:
@ -37,7 +37,7 @@ func NewHTTP() Service {
|
|||||||
func (s *HTTP) CheckArgs() {
|
func (s *HTTP) CheckArgs() {
|
||||||
var err error
|
var err error
|
||||||
if *s.cfg.Parent != "" && *s.cfg.ParentType == "" {
|
if *s.cfg.Parent != "" && *s.cfg.ParentType == "" {
|
||||||
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>")
|
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
|
||||||
}
|
}
|
||||||
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" {
|
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" {
|
||||||
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
||||||
@ -86,7 +86,7 @@ func (s *HTTP) InitService() {
|
|||||||
go func() {
|
go func() {
|
||||||
//循环检查ssh网络连通性
|
//循环检查ssh网络连通性
|
||||||
for {
|
for {
|
||||||
conn, err := utils.ConnectHost(s.domainResolver.MustResolve(*s.cfg.Parent), *s.cfg.Timeout*2)
|
conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = conn.Write([]byte{0})
|
_, err = conn.Write([]byte{0})
|
||||||
}
|
}
|
||||||
@ -170,9 +170,10 @@ func (s *HTTP) callback(inConn net.Conn) {
|
|||||||
} else if *s.cfg.Always {
|
} else if *s.cfg.Always {
|
||||||
useProxy = true
|
useProxy = true
|
||||||
} else {
|
} else {
|
||||||
s.checker.Add(address)
|
k := s.Resolve(address)
|
||||||
|
s.checker.Add(k)
|
||||||
//var n, m uint
|
//var n, m uint
|
||||||
useProxy, _, _ = s.checker.IsBlocked(req.Host)
|
useProxy, _, _ = s.checker.IsBlocked(k)
|
||||||
//log.Printf("blocked ? : %v, %s , fail:%d ,success:%d", useProxy, address, n, m)
|
//log.Printf("blocked ? : %v, %s , fail:%d ,success:%d", useProxy, address, n, m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -215,7 +216,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outConn, err = utils.ConnectHost(s.domainResolver.MustResolve(address), *s.cfg.Timeout)
|
outConn, err = utils.ConnectHost(s.Resolve(address), *s.cfg.Timeout)
|
||||||
}
|
}
|
||||||
tryCount++
|
tryCount++
|
||||||
if err == nil || tryCount > maxTryCount {
|
if err == nil || tryCount > maxTryCount {
|
||||||
@ -308,7 +309,7 @@ func (s *HTTP) ConnectSSH() (err error) {
|
|||||||
if s.sshClient != nil {
|
if s.sshClient != nil {
|
||||||
s.sshClient.Close()
|
s.sshClient.Close()
|
||||||
}
|
}
|
||||||
s.sshClient, err = ssh.Dial("tcp", s.domainResolver.MustResolve(*s.cfg.Parent), &config)
|
s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
|
||||||
<-s.lockChn
|
<-s.lockChn
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -322,7 +323,7 @@ func (s *HTTP) InitOutConnPool() {
|
|||||||
*s.cfg.KCPMethod,
|
*s.cfg.KCPMethod,
|
||||||
*s.cfg.KCPKey,
|
*s.cfg.KCPKey,
|
||||||
s.cfg.CertBytes, s.cfg.KeyBytes,
|
s.cfg.CertBytes, s.cfg.KeyBytes,
|
||||||
s.domainResolver.MustResolve(*s.cfg.Parent),
|
s.Resolve(*s.cfg.Parent),
|
||||||
*s.cfg.Timeout,
|
*s.cfg.Timeout,
|
||||||
*s.cfg.PoolSize,
|
*s.cfg.PoolSize,
|
||||||
*s.cfg.PoolSize*2,
|
*s.cfg.PoolSize*2,
|
||||||
@ -330,7 +331,11 @@ func (s *HTTP) InitOutConnPool() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (s *HTTP) InitBasicAuth() (err error) {
|
func (s *HTTP) InitBasicAuth() (err error) {
|
||||||
s.basicAuth = utils.NewBasicAuth()
|
if *s.cfg.DNSAddress != "" {
|
||||||
|
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
|
||||||
|
} else {
|
||||||
|
s.basicAuth = utils.NewBasicAuth(nil)
|
||||||
|
}
|
||||||
if *s.cfg.AuthURL != "" {
|
if *s.cfg.AuthURL != "" {
|
||||||
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
|
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
|
||||||
log.Printf("auth from %s", *s.cfg.AuthURL)
|
log.Printf("auth from %s", *s.cfg.AuthURL)
|
||||||
@ -364,7 +369,11 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
|
|||||||
}
|
}
|
||||||
if inPort == outPort {
|
if inPort == outPort {
|
||||||
var outIPs []net.IP
|
var outIPs []net.IP
|
||||||
|
if *s.cfg.DNSAddress != "" {
|
||||||
|
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
|
||||||
|
} else {
|
||||||
outIPs, err = net.LookupIP(outDomain)
|
outIPs, err = net.LookupIP(outDomain)
|
||||||
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, ip := range outIPs {
|
for _, ip := range outIPs {
|
||||||
if ip.String() == inIP {
|
if ip.String() == inIP {
|
||||||
@ -388,3 +397,9 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
func (s *HTTP) Resolve(address string) string {
|
||||||
|
if *s.cfg.DNSAddress == "" {
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
return s.domainResolver.MustResolve(address)
|
||||||
|
}
|
||||||
|
|||||||
@ -42,7 +42,7 @@ func (s *Socks) CheckArgs() {
|
|||||||
}
|
}
|
||||||
if *s.cfg.Parent != "" {
|
if *s.cfg.Parent != "" {
|
||||||
if *s.cfg.ParentType == "" {
|
if *s.cfg.ParentType == "" {
|
||||||
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>")
|
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
|
||||||
}
|
}
|
||||||
if *s.cfg.ParentType == "tls" {
|
if *s.cfg.ParentType == "tls" {
|
||||||
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
||||||
@ -90,7 +90,7 @@ func (s *Socks) InitService() {
|
|||||||
go func() {
|
go func() {
|
||||||
//循环检查ssh网络连通性
|
//循环检查ssh网络连通性
|
||||||
for {
|
for {
|
||||||
conn, err := utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout*2)
|
conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = conn.Write([]byte{0})
|
_, err = conn.Write([]byte{0})
|
||||||
}
|
}
|
||||||
@ -110,7 +110,6 @@ func (s *Socks) InitService() {
|
|||||||
if *s.cfg.ParentType == "ssh" {
|
if *s.cfg.ParentType == "ssh" {
|
||||||
log.Println("warn: socks udp not suppored for ssh")
|
log.Println("warn: socks udp not suppored for ssh")
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal)
|
s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal)
|
||||||
err := s.udpSC.ListenUDP(s.udpCallback)
|
err := s.udpSC.ListenUDP(s.udpCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -192,7 +191,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
|
|||||||
if parent == "" {
|
if parent == "" {
|
||||||
parent = *s.cfg.Parent
|
parent = *s.cfg.Parent
|
||||||
}
|
}
|
||||||
dstAddr, err := net.ResolveUDPAddr("udp", parent)
|
dstAddr, err := net.ResolveUDPAddr("udp", s.Resolve(parent))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("can't resolve address: %s", err)
|
log.Printf("can't resolve address: %s", err)
|
||||||
return
|
return
|
||||||
@ -248,7 +247,7 @@ func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
//本地代理
|
//本地代理
|
||||||
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port()))
|
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(s.Resolve(p.Host()), p.Port()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("can't resolve address: %s", err)
|
log.Printf("can't resolve address: %s", err)
|
||||||
return
|
return
|
||||||
@ -425,16 +424,17 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
|
|||||||
if utils.IsIternalIP(host) {
|
if utils.IsIternalIP(host) {
|
||||||
useProxy = false
|
useProxy = false
|
||||||
} else {
|
} else {
|
||||||
s.checker.Add(request.Addr())
|
k := s.Resolve(request.Addr())
|
||||||
useProxy, _, _ = s.checker.IsBlocked(request.Addr())
|
s.checker.Add(k)
|
||||||
|
useProxy, _, _ = s.checker.IsBlocked(k)
|
||||||
}
|
}
|
||||||
if useProxy {
|
if useProxy {
|
||||||
outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr())
|
outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr())
|
||||||
} else {
|
} else {
|
||||||
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
|
outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
|
outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
|
||||||
useProxy = false
|
useProxy = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -471,12 +471,12 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n
|
|||||||
case "tcp":
|
case "tcp":
|
||||||
if *s.cfg.ParentType == "tls" {
|
if *s.cfg.ParentType == "tls" {
|
||||||
var _outConn tls.Conn
|
var _outConn tls.Conn
|
||||||
_outConn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
|
_outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
|
||||||
outConn = net.Conn(&_outConn)
|
outConn = net.Conn(&_outConn)
|
||||||
} else if *s.cfg.ParentType == "kcp" {
|
} else if *s.cfg.ParentType == "kcp" {
|
||||||
outConn, err = utils.ConnectKCPHost(*s.cfg.Parent, *s.cfg.KCPMethod, *s.cfg.KCPKey)
|
outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), *s.cfg.KCPMethod, *s.cfg.KCPKey)
|
||||||
} else {
|
} else {
|
||||||
outConn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout)
|
outConn, err = utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("connect fail,%s", err)
|
err = fmt.Errorf("connect fail,%s", err)
|
||||||
@ -566,12 +566,16 @@ func (s *Socks) ConnectSSH() (err error) {
|
|||||||
if s.sshClient != nil {
|
if s.sshClient != nil {
|
||||||
s.sshClient.Close()
|
s.sshClient.Close()
|
||||||
}
|
}
|
||||||
s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config)
|
s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
|
||||||
<-s.lockChn
|
<-s.lockChn
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (s *Socks) InitBasicAuth() (err error) {
|
func (s *Socks) InitBasicAuth() (err error) {
|
||||||
s.basicAuth = utils.NewBasicAuth()
|
if *s.cfg.DNSAddress != "" {
|
||||||
|
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver)
|
||||||
|
} else {
|
||||||
|
s.basicAuth = utils.NewBasicAuth(nil)
|
||||||
|
}
|
||||||
if *s.cfg.AuthURL != "" {
|
if *s.cfg.AuthURL != "" {
|
||||||
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
|
s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry)
|
||||||
log.Printf("auth from %s", *s.cfg.AuthURL)
|
log.Printf("auth from %s", *s.cfg.AuthURL)
|
||||||
@ -605,7 +609,11 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
|
|||||||
}
|
}
|
||||||
if inPort == outPort {
|
if inPort == outPort {
|
||||||
var outIPs []net.IP
|
var outIPs []net.IP
|
||||||
|
if *s.cfg.DNSAddress != "" {
|
||||||
|
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
|
||||||
|
} else {
|
||||||
outIPs, err = net.LookupIP(outDomain)
|
outIPs, err = net.LookupIP(outDomain)
|
||||||
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, ip := range outIPs {
|
for _, ip := range outIPs {
|
||||||
if ip.String() == inIP {
|
if ip.String() == inIP {
|
||||||
@ -629,3 +637,9 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
func (s *Socks) Resolve(address string) string {
|
||||||
|
if *s.cfg.DNSAddress == "" {
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
return s.domainResolver.MustResolve(address)
|
||||||
|
}
|
||||||
|
|||||||
@ -451,7 +451,7 @@ func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, er
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(host) == 1 {
|
if len(host) == 1 && host[0] != "" {
|
||||||
req.Host = host[0]
|
req.Host = host[0]
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
|
|||||||
@ -173,11 +173,13 @@ type BasicAuth struct {
|
|||||||
authOkCode int
|
authOkCode int
|
||||||
authTimeout int
|
authTimeout int
|
||||||
authRetry int
|
authRetry int
|
||||||
|
dns *DomainResolver
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBasicAuth() BasicAuth {
|
func NewBasicAuth(dns *DomainResolver) BasicAuth {
|
||||||
return BasicAuth{
|
return BasicAuth{
|
||||||
data: NewConcurrentMap(),
|
data: NewConcurrentMap(),
|
||||||
|
dns: dns,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) {
|
func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) {
|
||||||
@ -241,18 +243,27 @@ func (ba *BasicAuth) checkFromURL(userpass, ip, target string) (err error) {
|
|||||||
if len(u) != 2 {
|
if len(u) != 2 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
URL := ba.authURL
|
URL := ba.authURL
|
||||||
if strings.Contains(URL, "?") {
|
if strings.Contains(URL, "?") {
|
||||||
URL += "&"
|
URL += "&"
|
||||||
} else {
|
} else {
|
||||||
URL += "?"
|
URL += "?"
|
||||||
}
|
}
|
||||||
URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, target)
|
URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&target=%s", u[0], u[1], ip, url.QueryEscape(target))
|
||||||
|
getURL := URL
|
||||||
|
var domain string
|
||||||
|
if ba.dns != nil {
|
||||||
|
_url, _ := url.Parse(ba.authURL)
|
||||||
|
domain = _url.Host
|
||||||
|
domainIP := ba.dns.MustResolve(domain)
|
||||||
|
getURL = strings.Replace(URL, domain, domainIP, 1)
|
||||||
|
}
|
||||||
var code int
|
var code int
|
||||||
var tryCount = 0
|
var tryCount = 0
|
||||||
var body []byte
|
var body []byte
|
||||||
for tryCount <= ba.authRetry {
|
for tryCount <= ba.authRetry {
|
||||||
body, code, err = HttpGet(URL, ba.authTimeout)
|
body, code, err = HttpGet(getURL, ba.authTimeout, domain)
|
||||||
if err == nil && code == ba.authOkCode {
|
if err == nil && code == ba.authOkCode {
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@ -855,6 +866,10 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if net.ParseIP(domain) != nil {
|
||||||
|
ip = domain
|
||||||
|
return
|
||||||
|
}
|
||||||
item, ok := a.data.Get(domain)
|
item, ok := a.data.Get(domain)
|
||||||
if ok {
|
if ok {
|
||||||
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
|
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
|
||||||
@ -868,10 +883,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
|||||||
}
|
}
|
||||||
a.data.Set(domain, item)
|
a.data.Set(domain, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
if net.ParseIP(domain) != nil {
|
|
||||||
return domain, nil
|
|
||||||
}
|
|
||||||
c := new(dns.Client)
|
c := new(dns.Client)
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||||
@ -880,7 +891,6 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
|||||||
if r == nil {
|
if r == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Rcode != dns.RcodeSuccess {
|
if r.Rcode != dns.RcodeSuccess {
|
||||||
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
|
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user