diff --git a/config.go b/config.go index 07bcbf9..2e5cb0e 100755 --- a/config.go +++ b/config.go @@ -63,6 +63,10 @@ func initConfig() (err error) { httpArgs.KCPKey = http.Flag("kcp-key", "key for kcp encrypt/decrypt data").Short('B').Default("encrypt").String() httpArgs.KCPMethod = http.Flag("kcp-method", "kcp encrypt/decrypt method").Short('M').Default("3des").String() httpArgs.LocalIPS = http.Flag("local bind ips", "if your host behind a nat,set your public ip here avoid dead loop").Short('g').Strings() + httpArgs.AuthURL = http.Flag("auth-url", "http basic auth username and password will send to this url,response http code equal to 'auth-code' means ok,others means fail.").Default("").String() + httpArgs.AuthURLTimeout = http.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int() + httpArgs.AuthURLOkCode = http.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int() + httpArgs.AuthURLRetry = http.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("1").Int() //########tcp######### tcp := app.Command("tcp", "proxy on tcp mode") @@ -138,6 +142,10 @@ func initConfig() (err error) { socksArgs.KCPKey = socks.Flag("kcp-key", "key for kcp encrypt/decrypt data").Short('B').Default("encrypt").String() socksArgs.KCPMethod = socks.Flag("kcp-method", "kcp encrypt/decrypt method").Short('M').Default("3des").String() socksArgs.LocalIPS = socks.Flag("local bind ips", "if your host behind a nat,set your public ip here avoid dead loop").Short('g').Strings() + socksArgs.AuthURL = socks.Flag("auth-url", "auth username and password will send to this url,response http code equal to 'auth-code' means ok,others means fail.").Default("").String() + socksArgs.AuthURLTimeout = socks.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int() + socksArgs.AuthURLOkCode = socks.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int() + socksArgs.AuthURLRetry = socks.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("1").Int() //parse args serviceName := kingpin.MustParse(app.Parse(os.Args[1:])) diff --git a/services/args.go b/services/args.go index c448fc9..100d23c 100644 --- a/services/args.go +++ b/services/args.go @@ -80,6 +80,10 @@ type HTTPArgs struct { Direct *string AuthFile *string Auth *[]string + AuthURL *string + AuthURLOkCode *int + AuthURLTimeout *int + AuthURLRetry *int ParentType *string LocalType *string Timeout *int @@ -129,6 +133,10 @@ type SocksArgs struct { Direct *string AuthFile *string Auth *[]string + AuthURL *string + AuthURLOkCode *int + AuthURLTimeout *int + AuthURLRetry *int KCPMethod *string KCPKey *string UDPParent *string diff --git a/services/http.go b/services/http.go index 0a3fc3d..73b6cdc 100644 --- a/services/http.go +++ b/services/http.go @@ -146,7 +146,7 @@ func (s *HTTP) callback(inConn net.Conn) { req, err = utils.NewHTTPRequest(&inConn, 4096, s.IsBasicAuth(), &s.basicAuth) if err != nil { if err != io.EOF { - log.Printf("decoder error , form %s, ERR:%s", err, inConn.RemoteAddr()) + log.Printf("decoder error , from %s, ERR:%s", inConn.RemoteAddr(), err) } utils.CloseConn(&inConn) return @@ -322,6 +322,10 @@ func (s *HTTP) InitOutConnPool() { } func (s *HTTP) InitBasicAuth() (err error) { s.basicAuth = utils.NewBasicAuth() + if *s.cfg.AuthURL != "" { + s.basicAuth.SetAuthURL(*s.cfg.AuthURL, *s.cfg.AuthURLOkCode, *s.cfg.AuthURLTimeout, *s.cfg.AuthURLRetry) + log.Printf("auth from %s", *s.cfg.AuthURL) + } if *s.cfg.AuthFile != "" { var n = 0 n, err = s.basicAuth.AddFromFile(*s.cfg.AuthFile) @@ -338,7 +342,7 @@ func (s *HTTP) InitBasicAuth() (err error) { return } func (s *HTTP) IsBasicAuth() bool { - return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 + return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 || *s.cfg.AuthURL != "" } func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool { inIP, inPort, err := net.SplitHostPort(inLocalAddr) diff --git a/services/socks.go b/services/socks.go index 3ce80dd..3758d60 100644 --- a/services/socks.go +++ b/services/socks.go @@ -576,7 +576,7 @@ func (s *Socks) InitBasicAuth() (err error) { return } func (s *Socks) IsBasicAuth() bool { - return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 + return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 || *s.cfg.AuthURL != "" } func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool { inIP, inPort, err := net.SplitHostPort(inLocalAddr) diff --git a/utils/functions.go b/utils/functions.go index 9842dd6..6718002 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -428,6 +428,29 @@ func GetKCPBlock(method, key string) (block kcp.BlockCrypt) { } return } +func HttpGet(URL string, timeout int) (body []byte, code int, err error) { + var tr *http.Transport + var client *http.Client + conf := &tls.Config{ + InsecureSkipVerify: true, + } + if strings.Contains(URL, "https://") { + tr = &http.Transport{TLSClientConfig: conf} + client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr} + } else { + tr = &http.Transport{} + client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr} + } + defer tr.CloseIdleConnections() + resp, err := client.Get(URL) + if err != nil { + return + } + defer resp.Body.Close() + code = resp.StatusCode + body, err = ioutil.ReadAll(resp.Body) + return +} // type sockaddr struct { // family uint16 diff --git a/utils/structs.go b/utils/structs.go index ae7723e..22a8004 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -176,7 +176,11 @@ func (c *Checker) Add(address string, isHTTPS bool, method, URL string, data []b } type BasicAuth struct { - data ConcurrentMap + data ConcurrentMap + authURL string + authOkCode int + authTimeout int + authRetry int } func NewBasicAuth() BasicAuth { @@ -184,6 +188,12 @@ func NewBasicAuth() BasicAuth { data: NewConcurrentMap(), } } +func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) { + ba.authURL = URL + ba.authOkCode = code + ba.authTimeout = timeout + ba.authRetry = retry +} func (ba *BasicAuth) AddFromFile(file string) (n int, err error) { _content, err := ioutil.ReadFile(file) if err != nil { @@ -219,15 +229,65 @@ func (ba *BasicAuth) CheckUserPass(user, pass string) (ok bool) { } return } -func (ba *BasicAuth) Check(userpass string) (ok bool) { +func (ba *BasicAuth) Check(userpass string, ip string) (ok bool) { u := strings.Split(strings.Trim(userpass, " "), ":") if len(u) == 2 { if p, _ok := ba.data.Get(u[0]); _ok { return p.(string) == u[1] } + if ba.authURL != "" { + err := ba.checkFromURL(userpass, ip) + if err == nil { + return true + } + log.Printf("%s", err) + } + return false } return } +func (ba *BasicAuth) checkFromURL(userpass, ip string) (err error) { + u := strings.Split(strings.Trim(userpass, " "), ":") + if len(u) != 2 { + return + } + URL := ba.authURL + if strings.Contains(URL, "?") { + URL += "&" + } else { + URL += "?" + } + URL += fmt.Sprintf("user=%s&pass=%s&ip=%s", u[0], u[1], ip) + var code int + var tryCount = 0 + var body []byte + for tryCount <= ba.authRetry { + body, code, err = HttpGet(URL, ba.authTimeout) + if err == nil && code == ba.authOkCode { + break + } else if err != nil { + err = fmt.Errorf("auth fail from url %s,resonse err:%s , %s", URL, err, ip) + } else { + if len(body) > 0 { + err = fmt.Errorf(string(body[0:100])) + } else { + err = fmt.Errorf("token error") + } + err = fmt.Errorf("auth fail from url %s,resonse code: %d, except: %d , %s , %s", URL, code, ba.authOkCode, ip, string(body)) + } + if err != nil && tryCount < ba.authRetry { + log.Print(err) + time.Sleep(time.Second * 2) + } + tryCount++ + } + if err != nil { + return + } + //log.Printf("auth success from auth url, %s", ip) + return +} + func (ba *BasicAuth) Total() (n int) { n = ba.data.Count() return @@ -325,6 +385,14 @@ func (req *HTTPRequest) BasicAuth() (err error) { CloseConn(req.conn) return } + if authorization == "" { + authorization, err = req.getHeader("Proxy-Authorization") + if err != nil { + fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized") + CloseConn(req.conn) + return + } + } //log.Printf("Authorization:%s", authorization) basic := strings.Fields(authorization) if len(basic) != 2 { @@ -338,7 +406,8 @@ func (req *HTTPRequest) BasicAuth() (err error) { CloseConn(req.conn) return } - authOk := (*req.basicAuth).Check(string(user)) + addr := strings.Split((*req.conn).RemoteAddr().String(), ":") + authOk := (*req.basicAuth).Check(string(user), addr[0]) //log.Printf("auth %s,%v", string(user), authOk) if !authOk { fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized") @@ -362,6 +431,7 @@ func (req *HTTPRequest) getHTTPURL() (URL string, err error) { func (req *HTTPRequest) getHeader(key string) (val string, err error) { key = strings.ToUpper(key) lines := strings.Split(string(req.HeadBuf), "\r\n") + //log.Println(lines) for _, line := range lines { line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2) if len(line) == 2 { @@ -373,7 +443,6 @@ func (req *HTTPRequest) getHeader(key string) (val string, err error) { } } } - err = fmt.Errorf("can not find HOST header") return }