From 7aeef3f8ba0d94eb9d2a922a5dc0fe4930d4899b Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Mon, 7 May 2018 17:14:20 +0800 Subject: [PATCH] Signed-off-by: arraykeys@gmail.com --- CHANGELOG | 6 +++++- config.go | 3 +++ services/args.go | 3 +++ services/http.go | 16 ++++++++-------- services/mux_bridge.go | 4 ++-- services/mux_client.go | 2 +- services/mux_server.go | 4 ++-- services/socks.go | 16 ++++++++-------- services/sps.go | 29 ++++++++++++++++++++++------- services/tcp.go | 6 +++--- services/tunnel_bridge.go | 4 ++-- services/tunnel_client.go | 2 +- services/tunnel_server.go | 4 ++-- services/udp.go | 2 +- utils/functions.go | 18 +++++++++--------- utils/serve-channel.go | 27 +++++++++++++++------------ utils/structs.go | 38 ++++++++++++++++++++++++-------------- 17 files changed, 111 insertions(+), 73 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index b2e0409..17d8c4a 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,7 +1,11 @@ proxy更新日志 v4.8 1.优化了SPS连接HTTP上级的指令,避免了某些代理不响应的问题. - +2.SPS功能增加了参数: +--disable-http:禁用http(s)代理 +--disable-socks:禁用socks代理. +默认都是false(开启). +3.重构了部分代码的日志部分,保证了日志按着预期输出. v4.7 diff --git a/config.go b/config.go index f43728f..e83ec32 100755 --- a/config.go +++ b/config.go @@ -244,6 +244,9 @@ func initConfig() (err error) { spsArgs.ParentCompress = sps.Flag("parent-compress", "auto compress/decompress data on parent connection").Short('M').Default("false").Bool() spsArgs.SSMethod = sps.Hidden().Flag("ss-method", "").Short('h').Default("aes-256-cfb").String() spsArgs.SSKey = sps.Hidden().Flag("ss-key", "").Short('j').Default("sspassword").String() + spsArgs.DisableHTTP = sps.Flag("disable-http", "disable http(s) proxy").Default("false").Bool() + spsArgs.DisableSocks5 = sps.Flag("disable-socks", "disable socks proxy").Default("false").Bool() + spsArgs.DisableSS = sps.Hidden().Flag("disable-ss", "").Default("false").Bool() //parse args serviceName := kingpin.MustParse(app.Parse(os.Args[1:])) diff --git a/services/args.go b/services/args.go index 59c4e0e..c54c5c8 100644 --- a/services/args.go +++ b/services/args.go @@ -234,6 +234,9 @@ type SPSArgs struct { ParentCompress *bool SSMethod *string SSKey *string + DisableHTTP *bool + DisableSocks5 *bool + DisableSS *bool } func (a *SPSArgs) Protocol() string { diff --git a/services/http.go b/services/http.go index 5584585..abb8f86 100644 --- a/services/http.go +++ b/services/http.go @@ -97,10 +97,10 @@ func (s *HTTP) CheckArgs() (err error) { func (s *HTTP) InitService() (err error) { s.InitBasicAuth() if *s.cfg.Parent != "" { - s.checker = utils.NewChecker(*s.cfg.HTTPTimeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct) + s.checker = utils.NewChecker(*s.cfg.HTTPTimeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct, s.log) } if *s.cfg.DNSAddress != "" { - (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL) + (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log) } if *s.cfg.ParentType == "ssh" { err = s.ConnectSSH() @@ -178,13 +178,13 @@ func (s *HTTP) Start(args interface{}, log *logger.Logger) (err error) { if addr != "" { host, port, _ := net.SplitHostPort(addr) p, _ := strconv.Atoi(port) - sc := utils.NewServerChannel(host, p) + sc := utils.NewServerChannel(host, p, s.log) if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.callback) } else if *s.cfg.LocalType == TYPE_TLS { err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes, s.callback) } else if *s.cfg.LocalType == TYPE_KCP { - err = sc.ListenKCP(s.cfg.KCP, s.callback) + err = sc.ListenKCP(s.cfg.KCP, s.callback, s.log) } if err != nil { return @@ -215,7 +215,7 @@ func (s *HTTP) callback(inConn net.Conn) { } var err interface{} var req utils.HTTPRequest - req, err = utils.NewHTTPRequest(&inConn, 4096, s.IsBasicAuth(), &s.basicAuth) + req, err = utils.NewHTTPRequest(&inConn, 4096, s.IsBasicAuth(), &s.basicAuth, s.log) if err != nil { if err != io.EOF { s.log.Printf("decoder error , from %s, ERR:%s", inConn.RemoteAddr(), err) @@ -321,7 +321,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut utils.IoBind((*inConn), outConn, func(err interface{}) { s.log.Printf("conn %s - %s released [%s]", inAddr, outAddr, req.Host) s.userConns.Remove(inAddr) - }) + }, s.log) s.log.Printf("conn %s - %s connected [%s]", inAddr, outAddr, req.Host) if c, ok := s.userConns.Get(inAddr); ok { (*c.(*net.Conn)).Close() @@ -403,9 +403,9 @@ func (s *HTTP) InitOutConnPool() { } func (s *HTTP) InitBasicAuth() (err error) { if *s.cfg.DNSAddress != "" { - s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver) + s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver, s.log) } else { - s.basicAuth = utils.NewBasicAuth(nil) + s.basicAuth = utils.NewBasicAuth(nil, s.log) } if *s.cfg.AuthURL != "" { s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) diff --git a/services/mux_bridge.go b/services/mux_bridge.go index 63302a6..ad8921c 100644 --- a/services/mux_bridge.go +++ b/services/mux_bridge.go @@ -90,13 +90,13 @@ func (s *MuxBridge) Start(args interface{}, log *logger.Logger) (err error) { host, port, _ := net.SplitHostPort(*s.cfg.Local) p, _ := strconv.Atoi(port) - sc := utils.NewServerChannel(host, p) + sc := utils.NewServerChannel(host, p, s.log) if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.handler) } else if *s.cfg.LocalType == TYPE_TLS { err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.handler) } else if *s.cfg.LocalType == TYPE_KCP { - err = sc.ListenKCP(s.cfg.KCP, s.handler) + err = sc.ListenKCP(s.cfg.KCP, s.handler, s.log) } if err != nil { return diff --git a/services/mux_client.go b/services/mux_client.go index 54eed3c..9b3efcb 100644 --- a/services/mux_client.go +++ b/services/mux_client.go @@ -300,6 +300,6 @@ func (s *MuxClient) ServeConn(inConn *smux.Stream, localAddr, ID string) { } else { utils.IoBind(inConn, outConn, func(err interface{}) { s.log.Printf("stream %s released", ID) - }) + }, s.log) } } diff --git a/services/mux_server.go b/services/mux_server.go index 4fd9b1e..776429c 100644 --- a/services/mux_server.go +++ b/services/mux_server.go @@ -200,7 +200,7 @@ func (s *MuxServer) Start(args interface{}, log *logger.Logger) (err error) { } host, port, _ := net.SplitHostPort(*s.cfg.Local) p, _ := strconv.Atoi(port) - s.sc = utils.NewServerChannel(host, p) + s.sc = utils.NewServerChannel(host, p, s.log) if *s.cfg.IsUDP { err = s.sc.ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) { s.udpChn <- MuxUDPItem{ @@ -258,7 +258,7 @@ func (s *MuxServer) Start(args interface{}, log *logger.Logger) (err error) { } else { utils.IoBind(inConn, outConn, func(err interface{}) { s.log.Printf("%s stream %s released", *s.cfg.Key, ID) - }) + }, s.log) } }) if err != nil { diff --git a/services/socks.go b/services/socks.go index d585408..16cccc5 100644 --- a/services/socks.go +++ b/services/socks.go @@ -110,9 +110,9 @@ func (s *Socks) CheckArgs() (err error) { func (s *Socks) InitService() (err error) { s.InitBasicAuth() if *s.cfg.DNSAddress != "" { - (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL) + (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log) } - s.checker = utils.NewChecker(*s.cfg.Timeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct) + s.checker = utils.NewChecker(*s.cfg.Timeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct, s.log) if *s.cfg.ParentType == "ssh" { e := s.ConnectSSH() if e != nil { @@ -147,7 +147,7 @@ func (s *Socks) InitService() (err error) { if *s.cfg.ParentType == "ssh" { s.log.Printf("warn: socks udp not suppored for ssh") } else { - s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal) + s.udpSC = utils.NewServerChannelHost(*s.cfg.UDPLocal, s.log) e := s.udpSC.ListenUDP(s.udpCallback) if e != nil { err = fmt.Errorf("init udp service fail, ERR: %s", e) @@ -197,13 +197,13 @@ func (s *Socks) Start(args interface{}, log *logger.Logger) (err error) { if *s.cfg.UDPParent != "" { s.log.Printf("use socks udp parent %s", *s.cfg.UDPParent) } - sc := utils.NewServerChannelHost(*s.cfg.Local) + sc := utils.NewServerChannelHost(*s.cfg.Local, s.log) if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.socksConnCallback) } else if *s.cfg.LocalType == TYPE_TLS { err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.socksConnCallback) } else if *s.cfg.LocalType == TYPE_KCP { - err = sc.ListenKCP(s.cfg.KCP, s.socksConnCallback) + err = sc.ListenKCP(s.cfg.KCP, s.socksConnCallback, s.log) } if err != nil { return @@ -557,7 +557,7 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque s.log.Printf("conn %s - %s connected", inAddr, request.Addr()) utils.IoBind(*inConn, outConn, func(err interface{}) { s.log.Printf("conn %s - %s released", inAddr, request.Addr()) - }) + }, s.log) if c, ok := s.userConns.Get(inAddr); ok { (*c.(*net.Conn)).Close() s.userConns.Remove(inAddr) @@ -689,9 +689,9 @@ func (s *Socks) ConnectSSH() (err error) { } func (s *Socks) InitBasicAuth() (err error) { if *s.cfg.DNSAddress != "" { - s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver) + s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver, s.log) } else { - s.basicAuth = utils.NewBasicAuth(nil) + s.basicAuth = utils.NewBasicAuth(nil, s.log) } if *s.cfg.AuthURL != "" { s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) diff --git a/services/sps.go b/services/sps.go index 5defff6..42d4a8d 100644 --- a/services/sps.go +++ b/services/sps.go @@ -66,7 +66,7 @@ func (s *SPS) CheckArgs() (err error) { func (s *SPS) InitService() (err error) { s.InitOutConnPool() if *s.cfg.DNSAddress != "" { - (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL) + (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log) } err = s.InitBasicAuth() if *s.cfg.SSMethod != "" && *s.cfg.SSKey != "" { @@ -128,13 +128,13 @@ func (s *SPS) Start(args interface{}, log *logger.Logger) (err error) { if addr != "" { host, port, _ := net.SplitHostPort(*s.cfg.Local) p, _ := strconv.Atoi(port) - sc := utils.NewServerChannel(host, p) + sc := utils.NewServerChannel(host, p, s.log) if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.callback) } else if *s.cfg.LocalType == TYPE_TLS { err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes, s.callback) } else if *s.cfg.LocalType == TYPE_KCP { - err = sc.ListenKCP(s.cfg.KCP, s.callback) + err = sc.ListenKCP(s.cfg.KCP, s.callback, s.log) } if err != nil { return @@ -206,6 +206,10 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) { var forwardBytes []byte //fmt.Printf("%v", header) if utils.IsSocks5(h) { + if *s.cfg.DisableSocks5 { + (*inConn).Close() + return + } //socks5 server var serverConn *socks.ServerConn if s.IsBasicAuth() { @@ -219,6 +223,10 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) { address = serverConn.Target() auth = serverConn.AuthData() } else if utils.IsHTTP(h) { + if *s.cfg.DisableHTTP { + (*inConn).Close() + return + } //http var request utils.HTTPRequest (*inConn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) @@ -254,8 +262,14 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) { } } else { //ss + if *s.cfg.DisableSS { + (*inConn).Close() + return + } + (*inConn).SetDeadline(time.Now().Add(time.Second * 5)) ssConn := ss.NewConn(*inConn, s.cipher.Copy()) address, err = ss.GetRequest(ssConn) + (*inConn).SetDeadline(time.Time{}) if err != nil { return } @@ -266,8 +280,9 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) { } *inConn = ssConn } - if err != nil { + if err != nil || address == "" { s.log.Printf("unknown request from: %s,%s", (*inConn).RemoteAddr(), string(h)) + (*inConn).Close() utils.CloseConn(inConn) err = errors.New("unknown request") return @@ -363,7 +378,7 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) { utils.IoBind((*inConn), outConn, func(err interface{}) { s.log.Printf("conn %s - %s released", inAddr, outAddr) s.userConns.Remove(inAddr) - }) + }, s.log) s.log.Printf("conn %s - %s connected", inAddr, outAddr) if c, ok := s.userConns.Get(inAddr); ok { (*c.(*net.Conn)).Close() @@ -373,9 +388,9 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) { } func (s *SPS) InitBasicAuth() (err error) { if *s.cfg.DNSAddress != "" { - s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver) + s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver, s.log) } else { - s.basicAuth = utils.NewBasicAuth(nil) + s.basicAuth = utils.NewBasicAuth(nil, s.log) } if *s.cfg.AuthURL != "" { s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) diff --git a/services/tcp.go b/services/tcp.go index 52d436c..eab2830 100644 --- a/services/tcp.go +++ b/services/tcp.go @@ -84,14 +84,14 @@ func (s *TCP) Start(args interface{}, log *logger.Logger) (err error) { s.log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent) host, port, _ := net.SplitHostPort(*s.cfg.Local) p, _ := strconv.Atoi(port) - sc := utils.NewServerChannel(host, p) + sc := utils.NewServerChannel(host, p, s.log) if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.callback) } else if *s.cfg.LocalType == TYPE_TLS { err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback) } else if *s.cfg.LocalType == TYPE_KCP { - err = sc.ListenKCP(s.cfg.KCP, s.callback) + err = sc.ListenKCP(s.cfg.KCP, s.callback, s.log) } if err != nil { return @@ -143,7 +143,7 @@ func (s *TCP) OutToTCP(inConn *net.Conn) (err error) { utils.IoBind((*inConn), outConn, func(err interface{}) { s.log.Printf("conn %s - %s released", inAddr, outAddr) s.userConns.Remove(inAddr) - }) + }, s.log) s.log.Printf("conn %s - %s connected", inAddr, outAddr) if c, ok := s.userConns.Get(inAddr); ok { (*c.(*net.Conn)).Close() diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index 952286d..333523d 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -73,7 +73,7 @@ func (s *TunnelBridge) Start(args interface{}, log *logger.Logger) (err error) { } host, port, _ := net.SplitHostPort(*s.cfg.Local) p, _ := strconv.Atoi(port) - sc := utils.NewServerChannel(host, p) + sc := utils.NewServerChannel(host, p, s.log) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback) if err != nil { @@ -173,7 +173,7 @@ func (s *TunnelBridge) callback(inConn net.Conn) { // s.cmClient.RemoveOne(key, ID) // s.cmServer.RemoveOne(serverID, ID) s.log.Printf("conn %s released", ID) - }) + }, s.log) // s.cmClient.Add(key, ID, &inConn) s.log.Printf("conn %s created", ID) diff --git a/services/tunnel_client.go b/services/tunnel_client.go index ef82024..29af4be 100644 --- a/services/tunnel_client.go +++ b/services/tunnel_client.go @@ -284,7 +284,7 @@ func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) { utils.IoBind(inConn, outConn, func(err interface{}) { s.log.Printf("conn %s released", ID) s.userConns.Remove(inAddr) - }) + }, s.log) if c, ok := s.userConns.Get(inAddr); ok { (*c.(*net.Conn)).Close() } diff --git a/services/tunnel_server.go b/services/tunnel_server.go index 3e84199..6c4086e 100644 --- a/services/tunnel_server.go +++ b/services/tunnel_server.go @@ -206,7 +206,7 @@ func (s *TunnelServer) Start(args interface{}, log *logger.Logger) (err error) { } host, port, _ := net.SplitHostPort(*s.cfg.Local) p, _ := strconv.Atoi(port) - s.sc = utils.NewServerChannel(host, p) + s.sc = utils.NewServerChannel(host, p, s.log) if *s.cfg.IsUDP { err = s.sc.ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) { s.udpChn <- UDPItem{ @@ -246,7 +246,7 @@ func (s *TunnelServer) Start(args interface{}, log *logger.Logger) (err error) { utils.IoBind(inConn, outConn, func(err interface{}) { s.userConns.Remove(inAddr) s.log.Printf("%s conn %s released", *s.cfg.Key, ID) - }) + }, s.log) if c, ok := s.userConns.Get(inAddr); ok { (*c.(*net.Conn)).Close() } diff --git a/services/udp.go b/services/udp.go index 240eb57..57cdd59 100644 --- a/services/udp.go +++ b/services/udp.go @@ -84,7 +84,7 @@ func (s *UDP) Start(args interface{}, log *logger.Logger) (err error) { } host, port, _ := net.SplitHostPort(*s.cfg.Local) p, _ := strconv.Atoi(port) - sc := utils.NewServerChannel(host, p) + sc := utils.NewServerChannel(host, p, s.log) s.sc = &sc err = sc.ListenUDP(s.callback) if err != nil { diff --git a/utils/functions.go b/utils/functions.go index 4774774..21ca84f 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -12,7 +12,7 @@ import ( "fmt" "io" "io/ioutil" - "log" + logger "log" "math/rand" "net" "net/http" @@ -32,7 +32,7 @@ import ( kcp "github.com/xtaci/kcp-go" ) -func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{})) { +func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) { go func() { defer func() { if err := recover(); err != nil { @@ -222,7 +222,7 @@ func Keygen() (err error) { cmd := exec.Command("sh", "-c", "openssl genrsa -out ca.key 2048") out, err = cmd.CombinedOutput() if err != nil { - log.Printf("err:%s", err) + logger.Printf("err:%s", err) return } fmt.Println(string(out)) @@ -231,7 +231,7 @@ func Keygen() (err error) { cmd = exec.Command("sh", "-c", cmdStr) out, err = cmd.CombinedOutput() if err != nil { - log.Printf("err:%s", err) + logger.Printf("err:%s", err) return } fmt.Println(string(out)) @@ -250,7 +250,7 @@ func Keygen() (err error) { cmd := exec.Command("sh", "-c", "openssl genrsa -out "+name+".key 2048") out, err = cmd.CombinedOutput() if err != nil { - log.Printf("err:%s", err) + logger.Printf("err:%s", err) return } fmt.Println(string(out)) @@ -260,7 +260,7 @@ func Keygen() (err error) { cmd = exec.Command("sh", "-c", cmdStr) out, err = cmd.CombinedOutput() if err != nil { - log.Printf("err:%s", err) + logger.Printf("err:%s", err) return } fmt.Println(string(out)) @@ -270,7 +270,7 @@ func Keygen() (err error) { cmd = exec.Command("sh", "-c", cmdStr) out, err = cmd.CombinedOutput() if err != nil { - log.Printf("err:%s", err) + logger.Printf("err:%s", err) return } @@ -284,7 +284,7 @@ proxy keygen ca client0 30 //generate client0.crt client0.key and use ca.crt sig cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048") out, err = cmd.CombinedOutput() if err != nil { - log.Printf("err:%s", err) + logger.Printf("err:%s", err) return } fmt.Println(string(out)) @@ -293,7 +293,7 @@ proxy keygen ca client0 30 //generate client0.crt client0.key and use ca.crt sig cmd = exec.Command("sh", "-c", cmdStr) out, err = cmd.CombinedOutput() if err != nil { - log.Printf("err:%s", err) + logger.Printf("err:%s", err) return } fmt.Println(string(out)) diff --git a/utils/serve-channel.go b/utils/serve-channel.go index 1c1552d..80d6fcc 100644 --- a/utils/serve-channel.go +++ b/utils/serve-channel.go @@ -5,7 +5,7 @@ import ( "crypto/x509" "errors" "fmt" - "log" + logger "log" "net" "runtime/debug" "strconv" @@ -21,23 +21,26 @@ type ServerChannel struct { Listener *net.Listener UDPListener *net.UDPConn errAcceptHandler func(err error) + log *logger.Logger } -func NewServerChannel(ip string, port int) ServerChannel { +func NewServerChannel(ip string, port int, log *logger.Logger) ServerChannel { return ServerChannel{ ip: ip, port: port, + log: log, errAcceptHandler: func(err error) { log.Printf("accept error , ERR:%s", err) }, } } -func NewServerChannelHost(host string) ServerChannel { +func NewServerChannelHost(host string, log *logger.Logger) ServerChannel { h, port, _ := net.SplitHostPort(host) p, _ := strconv.Atoi(port) return ServerChannel{ ip: h, port: p, + log: log, errAcceptHandler: func(err error) { log.Printf("accept error , ERR:%s", err) }, @@ -52,7 +55,7 @@ func (sc *ServerChannel) ListenTls(certBytes, keyBytes, caCertBytes []byte, fn f go func() { defer func() { if e := recover(); e != nil { - log.Printf("ListenTls crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("ListenTls crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() for { @@ -62,7 +65,7 @@ func (sc *ServerChannel) ListenTls(certBytes, keyBytes, caCertBytes []byte, fn f go func() { defer func() { if e := recover(); e != nil { - log.Printf("tls connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("tls connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() fn(conn) @@ -112,7 +115,7 @@ func (sc *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) { go func() { defer func() { if e := recover(); e != nil { - log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() for { @@ -122,7 +125,7 @@ func (sc *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) { go func() { defer func() { if e := recover(); e != nil { - log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() fn(conn) @@ -144,7 +147,7 @@ func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *ne go func() { defer func() { if e := recover(); e != nil { - log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() for { @@ -155,7 +158,7 @@ func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *ne go func() { defer func() { if e := recover(); e != nil { - log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() fn(packet, addr, srcAddr) @@ -169,7 +172,7 @@ func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *ne } return } -func (sc *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net.Conn)) (err error) { +func (sc *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net.Conn), log *logger.Logger) (err error) { lis, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", sc.ip, sc.port), config.Block, *config.DataShard, *config.ParityShard) if err == nil { if err = lis.SetDSCP(*config.DSCP); err != nil { @@ -189,7 +192,7 @@ func (sc *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net go func() { defer func() { if e := recover(); e != nil { - log.Printf("ListenKCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("ListenKCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() for { @@ -199,7 +202,7 @@ func (sc *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net go func() { defer func() { if e := recover(); e != nil { - log.Printf("kcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + sc.log.Printf("kcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) } }() conn.SetStreamMode(true) diff --git a/utils/structs.go b/utils/structs.go index d15bb6f..69f415c 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -9,7 +9,7 @@ import ( "fmt" "io" "io/ioutil" - "log" + logger "log" "net" "net/url" "strings" @@ -30,6 +30,7 @@ type Checker struct { interval int64 timeout int isStop bool + log *logger.Logger } type CheckerItem struct { IsHTTPS bool @@ -45,12 +46,13 @@ type CheckerItem struct { //NewChecker args: //timeout : tcp timeout milliseconds ,connect to host //interval: recheck domain interval seconds -func NewChecker(timeout int, interval int64, blockedFile, directFile string) Checker { +func NewChecker(timeout int, interval int64, blockedFile, directFile string, log *logger.Logger) Checker { ch := Checker{ data: NewConcurrentMap(), interval: interval, timeout: timeout, isStop: false, + log: log, } ch.blockedMap = ch.loadMap(blockedFile) ch.directMap = ch.loadMap(directFile) @@ -72,7 +74,7 @@ func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) { if PathExists(f) { _contents, err := ioutil.ReadFile(f) if err != nil { - log.Printf("load file err:%s", err) + c.log.Printf("load file err:%s", err) return } for _, line := range strings.Split(string(_contents), "\n") { @@ -151,7 +153,7 @@ func (c *Checker) IsBlocked(address string) (blocked bool, failN, successN uint) func (c *Checker) domainIsInMap(address string, blockedMap bool) bool { u, err := url.Parse("http://" + address) if err != nil { - log.Printf("blocked check , url parse err:%s", err) + c.log.Printf("blocked check , url parse err:%s", err) return true } domainSlice := strings.Split(u.Hostname(), ".") @@ -189,12 +191,14 @@ type BasicAuth struct { authTimeout int authRetry int dns *DomainResolver + log *logger.Logger } -func NewBasicAuth(dns *DomainResolver) BasicAuth { +func NewBasicAuth(dns *DomainResolver, log *logger.Logger) BasicAuth { return BasicAuth{ data: NewConcurrentMap(), dns: dns, + log: log, } } func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) { @@ -247,7 +251,7 @@ func (ba *BasicAuth) Check(userpass string, ip, target string) (ok bool) { if err == nil { return true } - log.Printf("%s", err) + ba.log.Printf("%s", err) } return false } @@ -296,7 +300,7 @@ func (ba *BasicAuth) checkFromURL(userpass, ip, target string) (err error) { err = fmt.Errorf("auth fail from url %s,resonse code: %d, except: %d , %s , %s", URL, code, ba.authOkCode, ip, b) } if err != nil && tryCount < ba.authRetry { - log.Print(err) + ba.log.Print(err) time.Sleep(time.Second * 2) } tryCount++ @@ -322,13 +326,15 @@ type HTTPRequest struct { hostOrURL string isBasicAuth bool basicAuth *BasicAuth + log *logger.Logger } -func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth, header ...[]byte) (req HTTPRequest, err error) { +func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth, log *logger.Logger, header ...[]byte) (req HTTPRequest, err error) { buf := make([]byte, bufSize) n := 0 req = HTTPRequest{ conn: inConn, + log: log, } if header != nil && len(header) == 1 && len(header[0]) > 1 { buf = header[0] @@ -548,12 +554,14 @@ func (op *OutConn) Get() (conn net.Conn, err error) { type ConnManager struct { pool ConcurrentMap l *sync.Mutex + log *logger.Logger } -func NewConnManager() ConnManager { +func NewConnManager(log *logger.Logger) ConnManager { cm := ConnManager{ pool: NewConcurrentMap(), l: &sync.Mutex{}, + log: log, } return cm } @@ -570,7 +578,7 @@ func (cm *ConnManager) Add(key, ID string, conn *net.Conn) { (*v.(*net.Conn)).Close() } conns.Set(ID, conn) - log.Printf("%s conn added", key) + cm.log.Printf("%s conn added", key) return conns }) } @@ -581,7 +589,7 @@ func (cm *ConnManager) Remove(key string) { conns.IterCb(func(key string, v interface{}) { CloseConn(v.(*net.Conn)) }) - log.Printf("%s conns closed", key) + cm.log.Printf("%s conns closed", key) } cm.pool.Remove(key) } @@ -596,7 +604,7 @@ func (cm *ConnManager) RemoveOne(key string, ID string) { (*v.(*net.Conn)).Close() conns.Remove(ID) cm.pool.Set(key, conns) - log.Printf("%s %s conn closed", key, ID) + cm.log.Printf("%s %s conn closed", key, ID) } } } @@ -652,6 +660,7 @@ type DomainResolver struct { ttl int dnsAddrress string data ConcurrentMap + log *logger.Logger } type DomainResolverItem struct { ip string @@ -659,12 +668,13 @@ type DomainResolverItem struct { expiredAt int64 } -func NewDomainResolver(dnsAddrress string, ttl int) DomainResolver { +func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver { return DomainResolver{ ttl: ttl, dnsAddrress: dnsAddrress, data: NewConcurrentMap(), + log: log, } } func (a *DomainResolver) MustResolve(address string) (ip string) { @@ -679,7 +689,7 @@ func (a *DomainResolver) Resolve(address string) (ip string, err error) { if port != "" { ip = net.JoinHostPort(ip, port) } - log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache) + a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache) //a.PrintData() }() if strings.Contains(domain, ":") {