add dns support

Signed-off-by: arraykeys@gmail.com <arraykeys@gmail.com>
This commit is contained in:
arraykeys@gmail.com
2018-02-26 18:44:14 +08:00
parent 983912e44e
commit dee517217e
8 changed files with 139 additions and 21 deletions

View File

@ -650,6 +650,8 @@ Then access to the local 8080 port is access to the proxy port 38080 on the VPS,
### How to use the source code? ### How to use the source code?
use command cd to enter your go SRC directory and then use command cd to enter your go SRC directory and then
mkdir snail007
cd snail007
execute `git clone https://github.com/snail007/goproxy.git ./proxy` execute `git clone https://github.com/snail007/goproxy.git ./proxy`
Direct compilation: `go build` Direct compilation: `go build`
execution: `go run *.go` execution: `go run *.go`

View File

@ -656,8 +656,9 @@ KCP协议需要-B参数设置一个密码用于加密解密数据
- 欢迎加群反馈... - 欢迎加群反馈...
### 如何使用源码? ### 如何使用源码?
建议go1.8,不保证>=1.9能用. 建议go1.8.5,不保证>=1.9能用.
cd进入你的go src目录,然后git clone https://github.com/snail007/goproxy.git ./proxy 即可. cd进入你的go src目录,新建文件夹snail007,
cd进入snail007,然后git clone https://github.com/snail007/goproxy.git ./proxy 即可.
编译直接:go build 编译直接:go build
运行: go run *.go 运行: go run *.go
utils是工具包,service是具体的每个服务类. utils是工具包,service是具体的每个服务类.

View File

@ -76,6 +76,8 @@ func initConfig() (err error) {
httpArgs.AuthURLTimeout = http.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int() httpArgs.AuthURLTimeout = http.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int()
httpArgs.AuthURLOkCode = http.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int() httpArgs.AuthURLOkCode = http.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int()
httpArgs.AuthURLRetry = http.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("1").Int() httpArgs.AuthURLRetry = http.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("1").Int()
httpArgs.DNSAddress = http.Flag("dns-address", "if set this, proxy will use this dns for resolve doamin").Short('q').Default("").String()
httpArgs.DNSTTL = http.Flag("dns-ttl", "caching seconds of dns query result").Short('e').Default("300").Int()
//########tcp######### //########tcp#########
tcp := app.Command("tcp", "proxy on tcp mode") tcp := app.Command("tcp", "proxy on tcp mode")
@ -182,6 +184,8 @@ func initConfig() (err error) {
socksArgs.AuthURLTimeout = socks.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int() socksArgs.AuthURLTimeout = socks.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int()
socksArgs.AuthURLOkCode = socks.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int() socksArgs.AuthURLOkCode = socks.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int()
socksArgs.AuthURLRetry = socks.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("0").Int() socksArgs.AuthURLRetry = socks.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("0").Int()
socksArgs.DNSAddress = socks.Flag("dns-address", "if set this, proxy will use this dns for resolve doamin").Short('q').Default("").String()
socksArgs.DNSTTL = socks.Flag("dns-ttl", "caching seconds of dns query result").Short('e').Default("300").Int()
//parse args //parse args
serviceName := kingpin.MustParse(app.Parse(os.Args[1:])) serviceName := kingpin.MustParse(app.Parse(os.Args[1:]))

View File

@ -138,6 +138,8 @@ type HTTPArgs struct {
KCPMethod *string KCPMethod *string
KCPKey *string KCPKey *string
LocalIPS *[]string LocalIPS *[]string
DNSAddress *string
DNSTTL *int
} }
type UDPArgs struct { type UDPArgs struct {
Parent *string Parent *string
@ -182,6 +184,8 @@ type SocksArgs struct {
UDPParent *string UDPParent *string
UDPLocal *string UDPLocal *string
LocalIPS *[]string LocalIPS *[]string
DNSAddress *string
DNSTTL *int
} }
func (a *TCPArgs) Protocol() string { func (a *TCPArgs) Protocol() string {

View File

@ -16,12 +16,13 @@ import (
) )
type HTTP struct { type HTTP struct {
outPool utils.OutPool outPool utils.OutPool
cfg HTTPArgs cfg HTTPArgs
checker utils.Checker checker utils.Checker
basicAuth utils.BasicAuth basicAuth utils.BasicAuth
sshClient *ssh.Client sshClient *ssh.Client
lockChn chan bool lockChn chan bool
domainResolver utils.DomainResolver
} }
func NewHTTP() Service { func NewHTTP() Service {
@ -74,6 +75,9 @@ func (s *HTTP) InitService() {
if *s.cfg.Parent != "" { if *s.cfg.Parent != "" {
s.checker = utils.NewChecker(*s.cfg.HTTPTimeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct) s.checker = utils.NewChecker(*s.cfg.HTTPTimeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct)
} }
if *s.cfg.DNSAddress != "" {
(*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL)
}
if *s.cfg.ParentType == "ssh" { if *s.cfg.ParentType == "ssh" {
err := s.ConnectSSH() err := s.ConnectSSH()
if err != nil { if err != nil {
@ -82,7 +86,7 @@ func (s *HTTP) InitService() {
go func() { go func() {
//循环检查ssh网络连通性 //循环检查ssh网络连通性
for { for {
conn, err := utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout*2) conn, err := utils.ConnectHost(s.domainResolver.MustResolve(*s.cfg.Parent), *s.cfg.Timeout*2)
if err == nil { if err == nil {
_, err = conn.Write([]byte{0}) _, err = conn.Write([]byte{0})
} }
@ -211,7 +215,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
} }
} }
} else { } else {
outConn, err = utils.ConnectHost(address, *s.cfg.Timeout) outConn, err = utils.ConnectHost(s.domainResolver.MustResolve(address), *s.cfg.Timeout)
} }
tryCount++ tryCount++
if err == nil || tryCount > maxTryCount { if err == nil || tryCount > maxTryCount {
@ -304,7 +308,7 @@ func (s *HTTP) ConnectSSH() (err error) {
if s.sshClient != nil { if s.sshClient != nil {
s.sshClient.Close() s.sshClient.Close()
} }
s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config) s.sshClient, err = ssh.Dial("tcp", s.domainResolver.MustResolve(*s.cfg.Parent), &config)
<-s.lockChn <-s.lockChn
return return
} }
@ -318,7 +322,7 @@ func (s *HTTP) InitOutConnPool() {
*s.cfg.KCPMethod, *s.cfg.KCPMethod,
*s.cfg.KCPKey, *s.cfg.KCPKey,
s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CertBytes, s.cfg.KeyBytes,
*s.cfg.Parent, s.domainResolver.MustResolve(*s.cfg.Parent),
*s.cfg.Timeout, *s.cfg.Timeout,
*s.cfg.PoolSize, *s.cfg.PoolSize,
*s.cfg.PoolSize*2, *s.cfg.PoolSize*2,

View File

@ -6,10 +6,10 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"runtime/debug"
"snail007/proxy/utils" "snail007/proxy/utils"
"snail007/proxy/utils/aes" "snail007/proxy/utils/aes"
"snail007/proxy/utils/socks" "snail007/proxy/utils/socks"
"runtime/debug"
"strings" "strings"
"time" "time"
@ -17,12 +17,13 @@ import (
) )
type Socks struct { type Socks struct {
cfg SocksArgs cfg SocksArgs
checker utils.Checker checker utils.Checker
basicAuth utils.BasicAuth basicAuth utils.BasicAuth
sshClient *ssh.Client sshClient *ssh.Client
lockChn chan bool lockChn chan bool
udpSC utils.ServerChannel udpSC utils.ServerChannel
domainResolver utils.DomainResolver
} }
func NewSocks() Service { func NewSocks() Service {
@ -77,6 +78,9 @@ func (s *Socks) CheckArgs() {
} }
func (s *Socks) InitService() { func (s *Socks) InitService() {
s.InitBasicAuth() s.InitBasicAuth()
if *s.cfg.DNSAddress != "" {
(*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL)
}
s.checker = utils.NewChecker(*s.cfg.Timeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct) s.checker = utils.NewChecker(*s.cfg.Timeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct)
if *s.cfg.ParentType == "ssh" { if *s.cfg.ParentType == "ssh" {
err := s.ConnectSSH() err := s.ConnectSSH()

View File

@ -431,7 +431,7 @@ func GetKCPBlock(method, key string) (block kcp.BlockCrypt) {
} }
return return
} }
func HttpGet(URL string, timeout int) (body []byte, code int, err error) { func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, err error) {
var tr *http.Transport var tr *http.Transport
var client *http.Client var client *http.Client
conf := &tls.Config{ conf := &tls.Config{
@ -445,7 +445,16 @@ func HttpGet(URL string, timeout int) (body []byte, code int, err error) {
client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr} client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr}
} }
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
resp, err := client.Get(URL)
//resp, err := client.Get(URL)
req, err := http.NewRequest("GET", URL, nil)
if err != nil {
return
}
if len(host) == 1 {
req.Host = host[0]
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return return
} }

View File

@ -15,6 +15,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"src/github.com/miekg/dns"
) )
type Checker struct { type Checker struct {
@ -815,3 +817,91 @@ func (c *ClientKeyRouter) GetKey() string {
} }
} }
type DomainResolver struct {
ttl int
dnsAddrress string
data ConcurrentMap
}
type DomainResolverItem struct {
ip string
domain string
expiredAt int64
}
func NewDomainResolver(dnsAddrress string, ttl int) DomainResolver {
return DomainResolver{
ttl: ttl,
dnsAddrress: dnsAddrress,
data: NewConcurrentMap(),
}
}
func (a *DomainResolver) MustResolve(address string) (ip string) {
ip, _ = a.Resolve(address)
return
}
func (a *DomainResolver) Resolve(address string) (ip string, err error) {
domain := address
port := ""
defer func() {
if port != "" {
ip = net.JoinHostPort(ip, port)
}
}()
if strings.Contains(domain, ":") {
domain, port, err = net.SplitHostPort(domain)
if err != nil {
return
}
}
item, ok := a.data.Get(domain)
if ok {
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
ip = (*item.(*DomainResolverItem)).ip
return
}
} else {
item = &DomainResolverItem{
domain: domain,
}
a.data.Set(domain, item)
}
if net.ParseIP(domain) != nil {
return domain, nil
}
c := new(dns.Client)
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
m.RecursionDesired = true
r, _, err := c.Exchange(m, a.dnsAddrress)
if r == nil {
return
}
if r.Rcode != dns.RcodeSuccess {
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
return
}
for _, answer := range r.Answer {
if answer.Header().Rrtype == dns.TypeA {
info := strings.Fields(answer.String())
if len(info) >= 5 {
ip = info[4]
_item := item.(*DomainResolverItem)
(*_item).expiredAt = time.Now().Unix() + int64(a.ttl)
(*_item).ip = ip
}
return
}
}
return
}
func (a *DomainResolver) PrintData() {
for k, item := range a.data.Items() {
d := item.(*DomainResolverItem)
fmt.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt)
}
}