// Package dnsx implements a small caching resolver that looks up A records
// for host names through a single configured DNS server.
package dnsx
|
|
|
|
import (
|
|
"fmt"
|
|
logger "log"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"bitbucket.org/snail/proxy/utils/mapx"
|
|
dns "github.com/miekg/dns"
|
|
)
|
|
|
|
type DomainResolver struct {
|
|
ttl int
|
|
dnsAddrress string
|
|
data mapx.ConcurrentMap
|
|
log *logger.Logger
|
|
}
|
|
// DomainResolverItem is one cache entry held by a DomainResolver.
type DomainResolverItem struct {
	// ip is the resolved address in textual form.
	ip string

	// domain is the host name the entry was created for.
	domain string

	// expiredAt is the Unix time, in seconds, after which the entry is
	// considered stale and must be resolved again.
	expiredAt int64
}
|
|
|
|
func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver {
|
|
return DomainResolver{
|
|
ttl: ttl,
|
|
dnsAddrress: dnsAddrress,
|
|
data: mapx.NewConcurrentMap(),
|
|
log: log,
|
|
}
|
|
}
|
|
func (a *DomainResolver) DnsAddress() (address string) {
|
|
address = a.dnsAddrress
|
|
return
|
|
}
|
|
func (a *DomainResolver) MustResolve(address string) (ip string) {
|
|
ip, _ = a.Resolve(address)
|
|
return
|
|
}
|
|
func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
|
domain := address
|
|
port := ""
|
|
fromCache := "false"
|
|
defer func() {
|
|
if port != "" {
|
|
ip = net.JoinHostPort(ip, port)
|
|
}
|
|
a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache)
|
|
//a.PrintData()
|
|
}()
|
|
if strings.Contains(domain, ":") {
|
|
domain, port, err = net.SplitHostPort(domain)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
if net.ParseIP(domain) != nil {
|
|
ip = domain
|
|
fromCache = "ip ignore"
|
|
return
|
|
}
|
|
item, ok := a.data.Get(domain)
|
|
if ok {
|
|
//log.Println("find ", domain)
|
|
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
|
|
ip = (*item.(*DomainResolverItem)).ip
|
|
fromCache = "true"
|
|
//log.Println("from cache ", domain)
|
|
return
|
|
}
|
|
} else {
|
|
item = &DomainResolverItem{
|
|
domain: domain,
|
|
}
|
|
|
|
}
|
|
c := new(dns.Client)
|
|
c.DialTimeout = time.Millisecond * 5000
|
|
c.ReadTimeout = time.Millisecond * 5000
|
|
c.WriteTimeout = time.Millisecond * 5000
|
|
m := new(dns.Msg)
|
|
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
|
m.RecursionDesired = true
|
|
r, _, err := c.Exchange(m, a.dnsAddrress)
|
|
if r == nil {
|
|
return
|
|
}
|
|
if r.Rcode != dns.RcodeSuccess {
|
|
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
|
|
return
|
|
}
|
|
for _, answer := range r.Answer {
|
|
if answer.Header().Rrtype == dns.TypeA {
|
|
info := strings.Fields(answer.String())
|
|
if len(info) >= 5 {
|
|
ip = info[4]
|
|
_item := item.(*DomainResolverItem)
|
|
(*_item).expiredAt = time.Now().Unix() + int64(a.ttl)
|
|
(*_item).ip = ip
|
|
a.data.Set(domain, item)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
func (a *DomainResolver) PrintData() {
|
|
for k, item := range a.data.Items() {
|
|
d := item.(*DomainResolverItem)
|
|
a.log.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt)
|
|
}
|
|
}
|