From aff38118e577e1931064bb7be446bf9f8619d5c5 Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Wed, 20 Sep 2017 19:28:36 +0800 Subject: [PATCH] Signed-off-by: arraykeys@gmail.com --- config.go | 12 +- functions.go | 100 ++++--- main.go | 808 +++++++++++++++++++++++++++++++++------------------ structs.go | 165 +++++++++++ 4 files changed, 750 insertions(+), 335 deletions(-) diff --git a/config.go b/config.go index 289b1b8..def80ae 100755 --- a/config.go +++ b/config.go @@ -23,8 +23,12 @@ func initConfig() (err error) { pflag.BoolP("parent-tls", "X", false, "parent proxy is tls") pflag.BoolP("local-tls", "x", false, "local proxy is tls") + pflag.BoolP("parent-tcp", "W", false, "parent proxy is tcp") + pflag.BoolP("local-tcp", "w", true, "local proxy is tcp") + pflag.BoolP("parent-udp", "U", false, "parent is udp") + pflag.BoolP("local-udp", "u", false, "local proxy is udp") version := pflag.BoolP("version", "v", false, "show version") - pflag.BoolP("tcp", "C", false, "proxy on tcp") + pflag.BoolP("local-http", "z", false, "proxy on http") pflag.Bool("always", false, "always use parent proxy") pflag.Int("check-proxy-interval", 3, "check if proxy is okay every interval seconds") @@ -47,7 +51,11 @@ func initConfig() (err error) { cfg.BindPFlag("parent-tls", pflag.Lookup("parent-tls")) cfg.BindPFlag("local-tls", pflag.Lookup("local-tls")) - cfg.BindPFlag("tcp", pflag.Lookup("tcp")) + cfg.BindPFlag("parent-udp", pflag.Lookup("parent-udp")) + cfg.BindPFlag("local-udp", pflag.Lookup("local-udp")) + cfg.BindPFlag("parent-tcp", pflag.Lookup("parent-tcp")) + cfg.BindPFlag("local-tcp", pflag.Lookup("local-tcp")) + cfg.BindPFlag("local-http", pflag.Lookup("local-http")) cfg.BindPFlag("always", pflag.Lookup("always")) cfg.BindPFlag("check-proxy-interval", pflag.Lookup("check-proxy-interval")) cfg.BindPFlag("port", pflag.Lookup("port")) diff --git a/functions.go b/functions.go index c1f2afd..7c20a6d 100755 --- a/functions.go +++ b/functions.go @@ -10,9 +10,11 @@ import ( "net" "net/http" "os" + "os/signal" "runtime/debug" "strconv" "strings" + "syscall" "time" ) @@ -186,6 +188,7 @@ func HTTPGet(URL string, timeout int) (err error) { } func initOutPool(isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) { + var err error outPool, err = NewConnPool(poolConfig{ IsActive: func(conn interface{}) bool { return true }, Release: func(conn interface{}) { @@ -241,57 +244,66 @@ func initPoolDeamon(isTLS bool, certBytes, keyBytes []byte, address string, time } }() } -func getURL(header []byte, host string) (URL string, err error) { - if !strings.HasPrefix(host, "/") { - return host, nil +func IsBasicAuth() bool { + return cfg.GetString("auth-file") != "" || len(cfg.GetStringSlice("auth")) > 0 +} +func InitBasicAuth() (err error) { + basicAuth = NewBasicAuth() + if cfg.GetString("auth-file") != "" { + n, err := basicAuth.AddFromFile(cfg.GetString("auth-file")) + if err != nil { + return fmt.Errorf("auth-file ERR:%s", err) + } + log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total()) } - _host, err := getHeader("host", header) + if len(cfg.GetStringSlice("auth")) > 0 { + n := basicAuth.Add(cfg.GetStringSlice("auth")) + log.Printf("auth data added %d, total:%d", n, basicAuth.Total()) + } + return +} +func clean() { + //block main() + signalChan := make(chan os.Signal, 1) + cleanupDone := make(chan bool) + signal.Notify(signalChan, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT) + go func() { + for _ = range signalChan { + if outPool != nil { + fmt.Println("\nReceived an interrupt, stopping services...") + outPool.ReleaseAll() + //time.Sleep(time.Second * 10) + // fmt.Println("done") + } + cleanupDone <- true + } + }() + <-cleanupDone +} +func HTTPProxyDecoder(inConn *net.Conn) (useProxy bool, request *HTTPRequest, err error) { + var req HTTPRequest + req, err = NewHTTPRequest(inConn, 4096) if err != nil { + //log.Printf("NewHTTPRequest ERR:%s", err) return } - URL = fmt.Sprintf("http://%s%s", _host, host) + useProxy = false + if checker.data != nil { + useProxy, _, _ = checker.IsBlocked(req.Host) + } + request = &req return } -func getHeader(key string, headData []byte) (val string, err error) { - key = strings.ToUpper(key) - lines := strings.Split(string(headData), "\r\n") - for _, line := range lines { - line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2) - if len(line) == 2 { - k := strings.ToUpper(strings.Trim(line[0], " ")) - v := strings.Trim(line[1], " ") - if key == k { - val = v - return - } - } - } - err = fmt.Errorf("can not find HOST header") - return -} -func hostIsNoPort(host string) bool { - //host: [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234 - if strings.HasPrefix(host, "[") { - return strings.HasSuffix(host, "]") - } - return strings.Index(host, ":") == -1 -} -func fixHost(host string) string { - if !strings.HasPrefix(host, "[") && len(strings.Split(host, ":")) > 2 { - if strings.HasSuffix(host, ":80") { - return fmt.Sprintf("[%s]:80", host[:strings.LastIndex(host, ":")]) - } - if strings.HasSuffix(host, ":443") { - return fmt.Sprintf("[%s]:443", host[:strings.LastIndex(host, ":")]) - } - } - return host -} -type sockaddr struct { - family uint16 - data [14]byte -} +// type sockaddr struct { +// family uint16 +// data [14]byte +// } // const SO_ORIGINAL_DST = 80 diff --git a/main.go b/main.go index 1d05b3a..6e503dc 100644 --- a/main.go +++ b/main.go @@ -1,46 +1,180 @@ package main import ( - "bytes" - "encoding/base64" "fmt" "io" "io/ioutil" "log" "net" - "net/url" - "os" "os/exec" - "os/signal" "runtime/debug" - "strings" - "syscall" "time" ) const APP_VERSION = "2.2" var ( - checker Checker - proxyIsTls bool - localIsTls bool - proxyAddr string - isTCP bool - connTimeout int - certBytes []byte - keyBytes []byte - err error - outPool ConnPool - basicAuth BasicAuth - httpAuthorization bool + checker Checker + certBytes []byte + keyBytes []byte + outPool ConnPool + basicAuth BasicAuth ) func init() { - err = initConfig() + err := initConfig() + if err != nil { + log.Fatalf("err : %s", err) + } + //Init + err = Init() + if err != nil { + log.Fatalf("err : %s", err) + } + isLocalHTTP := cfg.GetBool("local-http") + isTLS := cfg.GetBool("local-tls") || cfg.GetBool("parent-tls") + isTCP := isLocalHTTP || isTLS || cfg.GetBool("local-tcp") || cfg.GetBool("parent-tcp") + + //InitTCP + if isTCP { + err = InitTCP() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitTLS + if isTLS { + err = InitTLS() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitUDP + if cfg.GetBool("local-udp") || cfg.GetBool("parent-udp") { + err = InitUDP() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitLocal + err = InitLocal() + if err != nil { + log.Fatalf("err : %s", err) + } + //InitLocalTCP + if cfg.GetBool("local-tcp") || cfg.GetBool("local-tls") || isLocalHTTP { + err = InitLocalTCP() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitLocalTLS + if cfg.GetBool("local-tls") { + err = InitLocalTLS() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitLocalHTTP + if isLocalHTTP { + err = InitLocalHTTP() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitLocalUDP + if cfg.GetBool("local-udp") { + err = InitLocalUDP() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitParent + if cfg.GetString("parent") != "" { + err = InitParent() + if err != nil { + log.Fatalf("err : %s", err) + } + //InitParentTCP + if cfg.GetBool("parent-tcp") || cfg.GetBool("parent-tls") { + err = InitParentTCP() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitParentTLS + if cfg.GetBool("parent-tls") { + err = InitParentTLS() + } + if err != nil { + log.Fatalf("err : %s", err) + } + //InitParentUDP + if cfg.GetBool("parent-udp") { + err = InitParentUDP() + } + if err != nil { + log.Fatalf("err : %s", err) + } + } +} +func Init() (err error) { + return +} +func InitTCP() (err error) { + return +} +func InitTLS() (err error) { + certBytes, err = ioutil.ReadFile(cfg.GetString("cert")) if err != nil { log.Printf("err : %s", err) return } + keyBytes, err = ioutil.ReadFile(cfg.GetString("key")) + if err != nil { + log.Printf("err : %s", err) + return + } + return +} +func InitUDP() (err error) { + return +} +func InitLocal() (err error) { + return +} +func InitLocalTCP() (err error) { + return +} +func InitLocalTLS() (err error) { + + return +} +func InitLocalHTTP() (err error) { + err = InitBasicAuth() + if err != nil { + return + } + return +} + +func InitLocalUDP() (err error) { + return +} +func InitParent() (err error) { + initOutPool(cfg.GetBool("parent-tls"), certBytes, keyBytes, cfg.GetString("parent"), cfg.GetInt("tcp-timeout"), cfg.GetInt("pool-size"), cfg.GetInt("pool-size")*2) + checker = NewChecker(cfg.GetInt("check-timeout"), int64(cfg.GetInt("check-interval")), cfg.GetString("blocked"), cfg.GetString("direct")) + log.Printf("use parent proxy : %s, udp : %v, tcp : %v, tls: %v", cfg.GetString("parent"), cfg.GetBool("parent-udp"), cfg.GetBool("parent-tcp"), cfg.GetBool("parent-tls")) + return +} +func InitParentTCP() (err error) { + return +} +func InitParentTLS() (err error) { + return +} +func InitParentUDP() (err error) { + return } func main() { //catch panic error @@ -50,284 +184,84 @@ func main() { log.Printf("err : %s,\ntrace:%s", e, string(debug.Stack())) } }() - //define command line args - proxyIsTls = cfg.GetBool("parent-tls") - localIsTls = cfg.GetBool("local-tls") - proxyAddr = cfg.GetString("parent") - bindIP := cfg.GetString("ip") - bindPort := cfg.GetInt("port") - timeout := cfg.GetInt("check-timeout") - connTimeout = cfg.GetInt("tcp-timeout") - interval := cfg.GetInt("check-interval") - certFile := cfg.GetString("cert") - keyFile := cfg.GetString("key") - blockedFile := cfg.GetString("blocked") - directFile := cfg.GetString("direct") - isTCP = cfg.GetBool("tcp") - poolInitSize := cfg.GetInt("pool-size") + sc := NewServerChannel(cfg.GetString("ip"), cfg.GetInt("port")) + if cfg.GetBool("local-tls") { + LocalTLSServer(&sc) + } else if cfg.GetBool("local-tcp") { + LocalTCPServer(&sc) + } else if cfg.GetBool("local-udp") { + LocalUDPServer(&sc) + } + log.Printf("proxy on %s , udp: %v, tcp: %v, tls: %v ,http: %v", (*sc.Listener).Addr(), cfg.GetBool("local-udp"), cfg.GetBool("local-tcp"), cfg.GetBool("local-tls"), cfg.GetBool("local-http")) - //check args required - if proxyIsTls && proxyAddr == "" { - log.Fatalf("parent proxy address required") - } + clean() +} - //check tls cert&key file - if certFile == "" { - certFile = "proxy.crt" - } - if keyFile == "" { - keyFile = "proxy.key" - } - if proxyIsTls || localIsTls { - certBytes, err = ioutil.ReadFile(certFile) +func CheckTCPDeocder(inConn *net.Conn) (useProxy bool, address string, req *HTTPRequest, err error) { + if cfg.GetBool("local-http") { + useProxy, req, err = HTTPProxyDecoder(inConn) if err != nil { - log.Printf("err : %s", err) + if err != io.EOF { + log.Printf("http proxy decode error , ERR:%s", err) + } return } - keyBytes, err = ioutil.ReadFile(keyFile) + address = req.Host + } else { + address = cfg.GetString("parent") + } + + return +} +func LocalTCPServer(sc *ServerChannel) { + (*sc).ListenTCP(func(inConn net.Conn) { + userProxy, address, req, err := CheckTCPDeocder(&inConn) if err != nil { - log.Printf("err : %s", err) + log.Printf("%s", err) return } - } - //init tls info string - var proxyIsTlsStr string - var localIsTlsStr string - protocolStr := "tcp" - if !isTCP { - protocolStr = "http(s)" - } - if proxyIsTls { - proxyIsTlsStr = "tls " - } - if localIsTls { - localIsTlsStr = "tls " - } - //init checker and pool if needed - if proxyAddr != "" { - if !isTCP && !cfg.GetBool("always") { - checker = NewChecker(timeout, int64(interval), blockedFile, directFile) - } - log.Printf("use %sparent %s proxy : %s", proxyIsTlsStr, protocolStr, proxyAddr) - initOutPool(proxyIsTls, certBytes, keyBytes, proxyAddr, connTimeout, poolInitSize, poolInitSize*2) - } else if isTCP { - log.Printf("tcp proxy need parent") + TCPOutBridge(&inConn, userProxy, address, req) + }) +} + +func LocalTLSServer(sc *ServerChannel) { + certBytes, err := ioutil.ReadFile(cfg.GetString("cert")) + if err != nil { + log.Fatalf("err : %s", err) return } - //init basic auth only in http mode - if !isTCP { - basicAuth = NewBasicAuth() - if cfg.GetString("auth-file") != "" { - httpAuthorization = true - n, err := basicAuth.AddFromFile(cfg.GetString("auth-file")) - if err != nil { - log.Fatalf("auth-file:%s", err) - } - log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total()) - } - if len(cfg.GetStringSlice("auth")) > 0 { - httpAuthorization = true - n := basicAuth.Add(cfg.GetStringSlice("auth")) - log.Printf("auth data added %d, total:%d", n, basicAuth.Total()) - } + keyBytes, err := ioutil.ReadFile(cfg.GetString("key")) + if err != nil { + log.Fatalf("err : %s", err) + return } + (*sc).ListenTls(certBytes, keyBytes, func(inConn net.Conn) { + userProxy, address, req, err := CheckTCPDeocder(&inConn) + if err != nil { + return + } + TCPOutBridge(&inConn, userProxy, address, req) + }) +} - //listen - sc := NewServerChannel(bindIP, bindPort) +func LocalUDPServer(sc *ServerChannel) { + (*sc).ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) { + + }) +} + +func TCPOutBridge(inConn *net.Conn, userProxy bool, address string, req *HTTPRequest) { + var outConn net.Conn + var _outConn interface{} var err error - if localIsTls { - err = sc.ListenTls(certBytes, keyBytes, connHandler) - } else { - err = sc.ListenTCP(connHandler) - } - //listen fail - if err != nil { - log.Fatalf("ERR:%s", err) - } else { - log.Printf("%s %sproxy on %s", protocolStr, localIsTlsStr, (*sc.Listener).Addr()) - } - //block main() - signalChan := make(chan os.Signal, 1) - cleanupDone := make(chan bool) - signal.Notify(signalChan, - os.Interrupt, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT) - go func() { - for _ = range signalChan { - if outPool != nil { - fmt.Println("\nReceived an interrupt, stopping services...") - outPool.ReleaseAll() - //time.Sleep(time.Second * 10) - // fmt.Println("done") - } - cleanupDone <- true - } - }() - <-cleanupDone -} -func connHandler(inConn net.Conn) { - defer func() { - err := recover() - if err != nil { - log.Printf("connHandler crashed,err:%s\nstack:%s", err, string(debug.Stack())) - closeConn(&inConn) - } - }() - if isTCP { - tcpHandler(&inConn) - } else { - httpHandler(&inConn) - } -} - -func tcpHandler(inConn *net.Conn) { - var outConn net.Conn - var _outConn interface{} - _outConn, err = outPool.Get() - if err != nil { - log.Printf("connect to %s , err:%s", proxyAddr, err) - closeConn(inConn) - return - } - outConn = _outConn.(net.Conn) - inAddr := (*inConn).RemoteAddr().String() - outAddr := outConn.RemoteAddr().String() - IoBind((*inConn), outConn, func(err error) { - log.Printf("conn %s - %s released", inAddr, outAddr) - closeConn(inConn) - closeConn(&outConn) - }, func(n int, d bool) {}, 0) - log.Printf("conn %s - %s connected", inAddr, outAddr) -} -func httpHandler(inConn *net.Conn) { - var b [4096]byte - var n int - n, err = (*inConn).Read(b[:]) - if err != nil { - if err != io.EOF { - log.Printf("read err:%s", err) - } - closeConn(inConn) - return - } - var method, host, address string - index := bytes.IndexByte(b[:], '\n') - if index == -1 { - log.Printf("data err:%s", string(b[:n])[:50]) - closeConn(inConn) - return - } - - fmt.Sscanf(string(b[:index]), "%s%s", &method, &host) - if method == "" || host == "" { - log.Printf("data err:%s", string(b[:n])[:50]) - closeConn(inConn) - return - } - isHTTPS := method == "CONNECT" - - //http basic auth,only http - if !isHTTPS { - if httpAuthorization { - //log.Printf("request :%s", string(b[:n])) - authorization, err := getHeader("Authorization", b[:n]) - if err != nil { - fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized") - closeConn(inConn) - return - } - //log.Printf("Authorization:%s", authorization) - basic := strings.Fields(authorization) - if len(basic) != 2 { - log.Printf("authorization data error,ERR:%s", authorization) - closeConn(inConn) - return - } - user, err := base64.StdEncoding.DecodeString(basic[1]) - if err != nil { - log.Printf("authorization data parse error,ERR:%s", err) - closeConn(inConn) - return - } - authOk := basicAuth.Check(string(user)) - //log.Printf("auth %s,%v", string(user), authOk) - if !authOk { - fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized") - closeConn(inConn) - return - } - } - } - - var bytes []byte - if isHTTPS { //https访问 - // [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234 - address = fixHost(host) - if hostIsNoPort(host) { //host不带端口, 默认443 - address = address + ":443" - } - } else { //http访问 - hostPortURL, err := url.Parse(host) - if err != nil { - log.Printf("url.Parse %s ERR:%s", host, err) - closeConn(inConn) - return - } - _host := fixHost(hostPortURL.Host) - address = _host - if hostIsNoPort(_host) { //host不带端口, 默认80 - address = _host + ":80" - } - if _host != hostPortURL.Host { - bytes = []byte(strings.Replace(string(b[:n]), hostPortURL.Host, _host, 1)) - host = strings.Replace(host, hostPortURL.Host, _host, 1) - } - } - //get url , reslut host is the full url - host, err = getURL(b[:n], host) - // log.Printf("body:%s", string(b[:n])) - // log.Printf("%s:%s", method, host) - if err != nil { - log.Printf("header data err:%s", err) - closeConn(inConn) - return - } - - useProxy := false - if proxyAddr != "" { - if cfg.GetBool("always") { - useProxy = true - } else { - if isHTTPS { - checker.Add(address, true, method, "", nil) - } else { - if bytes != nil { - checker.Add(address, false, method, host, bytes) - } else { - checker.Add(address, false, method, host, b[:n]) - } - } - useProxy, _, _ = checker.IsBlocked(address) - } - // var failN, successN uint - // useProxy, failN, successN = checker.IsBlocked(address) - //log.Printf("use proxy ? %s : %v ,fail:%d, success:%d", address, useProxy, failN, successN) - //log.Printf("use proxy ? %s : %v", address, useProxy) - } - - var outConn net.Conn - var _outConn interface{} - if useProxy { + if userProxy { _outConn, err = outPool.Get() if err == nil { outConn = _outConn.(net.Conn) } } else { - outConn, err = ConnectHost(address, connTimeout) + outConn, err = ConnectHost(address, cfg.GetInt("tcp-timeout")) } if err != nil { log.Printf("connect to %s , err:%s", address, err) @@ -336,20 +270,16 @@ func httpHandler(inConn *net.Conn) { } inAddr := (*inConn).RemoteAddr().String() outAddr := outConn.RemoteAddr().String() + //log.Printf("%s use proxy %v",address, userProxy) - if isHTTPS { - if useProxy { - outConn.Write(b[:n]) + if req != nil { + if req.IsHTTPS() && !userProxy { + req.HTTPSReply() } else { - fmt.Fprint(*inConn, "HTTP/1.1 200 Connection established\r\n\r\n") - } - } else { - if bytes != nil { - outConn.Write(bytes) - } else { - outConn.Write(b[:n]) + outConn.Write(req.headBuf) } } + IoBind(*inConn, outConn, func(err error) { log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address) closeConn(inConn) @@ -357,6 +287,306 @@ func httpHandler(inConn *net.Conn) { }, func(n int, d bool) {}, 0) log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address) } +func UDPOutBridge() { + +} + +// func DoUDP() { +// if cfg.GetBool("local-udp") { + +// } else { + +// } +// } + +// //DoTCP contains tcp && http +// func DoTCP() { +// //define command line args +// proxyIsTls = cfg.GetBool("parent-tls") +// localIsTls = cfg.GetBool("local-tls") +// proxyAddr = cfg.GetString("parent") +// bindIP := cfg.GetString("ip") +// bindPort := cfg.GetInt("port") +// timeout := cfg.GetInt("check-timeout") +// connTimeout = cfg.GetInt("tcp-timeout") +// interval := cfg.GetInt("check-interval") +// certFile := cfg.GetString("cert") +// keyFile := cfg.GetString("key") +// blockedFile := cfg.GetString("blocked") +// directFile := cfg.GetString("direct") + +// isTCP = cfg.GetBool("tcp") +// poolInitSize := cfg.GetInt("pool-size") + +// //check args required +// if proxyIsTls && proxyAddr == "" { +// log.Fatalf("parent proxy address required") +// } + +// //check tls cert&key file +// if certFile == "" { +// certFile = "proxy.crt" +// } +// if keyFile == "" { +// keyFile = "proxy.key" +// } +// if proxyIsTls || localIsTls { +// certBytes, err = ioutil.ReadFile(certFile) +// if err != nil { +// log.Printf("err : %s", err) +// return +// } +// keyBytes, err = ioutil.ReadFile(keyFile) +// if err != nil { +// log.Printf("err : %s", err) +// return +// } +// } +// //init tls info string +// var proxyIsTlsStr string +// var localIsTlsStr string +// protocolStr := "tcp" +// if !isTCP { +// protocolStr = "http(s)" +// } +// if proxyIsTls { +// proxyIsTlsStr = "tls " +// } +// if localIsTls { +// localIsTlsStr = "tls " +// } +// //init checker and pool if needed +// if proxyAddr != "" { +// if !isTCP && !cfg.GetBool("always") { +// checker = NewChecker(timeout, int64(interval), blockedFile, directFile) +// } +// log.Printf("use %sparent %s proxy : %s", proxyIsTlsStr, protocolStr, proxyAddr) +// initOutPool(proxyIsTls, certBytes, keyBytes, proxyAddr, connTimeout, poolInitSize, poolInitSize*2) +// } else if isTCP { +// log.Printf("tcp proxy need parent") +// return +// } +// //init basic auth only in http mode +// if !isTCP { +// basicAuth = NewBasicAuth() +// if cfg.GetString("auth-file") != "" { +// httpAuthorization = true +// n, err := basicAuth.AddFromFile(cfg.GetString("auth-file")) +// if err != nil { +// log.Fatalf("auth-file:%s", err) +// } +// log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total()) +// } +// if len(cfg.GetStringSlice("auth")) > 0 { +// httpAuthorization = true +// n := basicAuth.Add(cfg.GetStringSlice("auth")) +// log.Printf("auth data added %d, total:%d", n, basicAuth.Total()) +// } +// } + +// //listen +// sc := NewServerChannel(bindIP, bindPort) +// var err error +// if localIsTls { +// err = sc.ListenTls(certBytes, keyBytes, connHandler) +// } else { +// err = sc.ListenTCP(connHandler) +// } +// //listen fail +// if err != nil { +// log.Fatalf("ERR:%s", err) +// } else { +// log.Printf("%s %sproxy on %s", protocolStr, localIsTlsStr, (*sc.Listener).Addr()) +// } +// } +// func connHandler(inConn net.Conn) { +// defer func() { +// err := recover() +// if err != nil { +// log.Printf("connHandler crashed,err:%s\nstack:%s", err, string(debug.Stack())) +// closeConn(&inConn) +// } +// }() +// if isTCP { +// tcpHandler(&inConn) +// } else { +// httpHandler(&inConn) +// } +// } + +// func tcpHandler(inConn *net.Conn) { +// var outConn net.Conn +// var _outConn interface{} +// _outConn, err = outPool.Get() +// if err != nil { +// log.Printf("connect to %s , err:%s", proxyAddr, err) +// closeConn(inConn) +// return +// } +// outConn = _outConn.(net.Conn) +// inAddr := (*inConn).RemoteAddr().String() +// outAddr := outConn.RemoteAddr().String() +// IoBind((*inConn), outConn, func(err error) { +// log.Printf("conn %s - %s released", inAddr, outAddr) +// closeConn(inConn) +// closeConn(&outConn) +// }, func(n int, d bool) {}, 0) +// log.Printf("conn %s - %s connected", inAddr, outAddr) +// } +// func httpHandler(inConn *net.Conn) { +// var b [4096]byte +// var n int +// n, err = (*inConn).Read(b[:]) +// if err != nil { +// if err != io.EOF { +// log.Printf("read err:%s", err) +// } +// closeConn(inConn) +// return +// } +// var method, host, address string +// index := bytes.IndexByte(b[:], '\n') +// if index == -1 { +// log.Printf("data err:%s", string(b[:n])[:50]) +// closeConn(inConn) +// return +// } + +// fmt.Sscanf(string(b[:index]), "%s%s", &method, &host) +// if method == "" || host == "" { +// log.Printf("data err:%s", string(b[:n])[:50]) +// closeConn(inConn) +// return +// } +// isHTTPS := method == "CONNECT" + +// //http basic auth,only http +// if !isHTTPS { +// if httpAuthorization { +// //log.Printf("request :%s", string(b[:n])) +// authorization, err := getHeader("Authorization", b[:n]) +// if err != nil { +// fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized") +// closeConn(inConn) +// return +// } +// //log.Printf("Authorization:%s", authorization) +// basic := strings.Fields(authorization) +// if len(basic) != 2 { +// log.Printf("authorization data error,ERR:%s", authorization) +// closeConn(inConn) +// return +// } +// user, err := base64.StdEncoding.DecodeString(basic[1]) +// if err != nil { +// log.Printf("authorization data parse error,ERR:%s", err) +// closeConn(inConn) +// return +// } +// authOk := basicAuth.Check(string(user)) +// //log.Printf("auth %s,%v", string(user), authOk) +// if !authOk { +// fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized") +// closeConn(inConn) +// return +// } +// } +// } + +// var bytes []byte +// if isHTTPS { //https访问 +// // [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234 +// address = fixHost(host) +// if hostIsNoPort(host) { //host不带端口, 默认443 +// address = address + ":443" +// } +// } else { //http访问 +// hostPortURL, err := url.Parse(host) +// if err != nil { +// log.Printf("url.Parse %s ERR:%s", host, err) +// closeConn(inConn) +// return +// } +// _host := fixHost(hostPortURL.Host) +// address = _host +// if hostIsNoPort(_host) { //host不带端口, 默认80 +// address = _host + ":80" +// } +// if _host != hostPortURL.Host { +// bytes = []byte(strings.Replace(string(b[:n]), hostPortURL.Host, _host, 1)) +// host = strings.Replace(host, hostPortURL.Host, _host, 1) +// } +// } +// //get url , reslut host is the full url +// host, err = getURL(b[:n], host) +// // log.Printf("body:%s", string(b[:n])) +// // log.Printf("%s:%s", method, host) +// if err != nil { +// log.Printf("header data err:%s", err) +// closeConn(inConn) +// return +// } + +// useProxy := false +// if proxyAddr != "" { +// if cfg.GetBool("always") { +// useProxy = true +// } else { +// if isHTTPS { +// checker.Add(address, true, method, "", nil) +// } else { +// if bytes != nil { +// checker.Add(address, false, method, host, bytes) +// } else { +// checker.Add(address, false, method, host, b[:n]) +// } +// } +// useProxy, _, _ = checker.IsBlocked(address) +// } +// // var failN, successN uint +// // useProxy, failN, successN = checker.IsBlocked(address) +// //log.Printf("use proxy ? %s : %v ,fail:%d, success:%d", address, useProxy, failN, successN) +// //log.Printf("use proxy ? %s : %v", address, useProxy) +// } + +// var outConn net.Conn +// var _outConn interface{} +// if useProxy { +// _outConn, err = outPool.Get() +// if err == nil { +// outConn = _outConn.(net.Conn) +// } +// } else { +// outConn, err = ConnectHost(address, connTimeout) +// } +// if err != nil { +// log.Printf("connect to %s , err:%s", address, err) +// closeConn(inConn) +// return +// } +// inAddr := (*inConn).RemoteAddr().String() +// outAddr := outConn.RemoteAddr().String() + +// if isHTTPS { +// if useProxy { +// outConn.Write(b[:n]) +// } else { +// fmt.Fprint(*inConn, "HTTP/1.1 200 Connection established\r\n\r\n") +// } +// } else { +// if bytes != nil { +// outConn.Write(bytes) +// } else { +// outConn.Write(b[:n]) +// } +// } +// IoBind(*inConn, outConn, func(err error) { +// log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address) +// closeConn(inConn) +// closeConn(&outConn) +// }, func(n int, d bool) {}, 0) +// log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address) +// } func closeConn(conn *net.Conn) { if *conn != nil { (*conn).SetDeadline(time.Now().Add(time.Millisecond)) diff --git a/structs.go b/structs.go index 81e1c51..dd298c4 100644 --- a/structs.go +++ b/structs.go @@ -1,6 +1,10 @@ package main import ( + "bytes" + "encoding/base64" + "fmt" + "io" "io/ioutil" "log" "net" @@ -217,3 +221,164 @@ func (ba *BasicAuth) Total() (n int) { n = ba.data.Count() return } + +type HTTPRequest struct { + headBuf []byte + conn *net.Conn + Host string + Method string + URL string + hostOrURL string +} + +func NewHTTPRequest(inConn *net.Conn, bufSize int) (req HTTPRequest, err error) { + buf := make([]byte, bufSize) + len := 0 + req = HTTPRequest{ + conn: inConn, + } + len, err = (*inConn).Read(buf[:]) + if err != nil { + if err != io.EOF { + err = fmt.Errorf("http decoder read err:%s", err) + } + closeConn(inConn) + return + } + req.headBuf = buf[:len] + index := bytes.IndexByte(req.headBuf, '\n') + if index == -1 { + err = fmt.Errorf("http decoder data line err:%s", string(req.headBuf)[:50]) + closeConn(inConn) + return + } + fmt.Sscanf(string(req.headBuf[:index]), "%s%s", &req.Method, &req.hostOrURL) + if req.Method == "" || req.hostOrURL == "" { + err = fmt.Errorf("http decoder data err:%s", string(req.headBuf)[:50]) + closeConn(inConn) + return + } + req.Method = strings.ToUpper(req.Method) + log.Printf("%s:%s", req.Method, req.hostOrURL) + + if req.IsHTTPS() { + err = req.HTTPS() + } else { + err = req.HTTP() + } + return +} +func (req *HTTPRequest) HTTP() (err error) { + if IsBasicAuth() { + err = req.BasicAuth() + if err != nil { + return + } + } + req.URL, err = req.getHTTPURL() + if err == nil { + u, _ := url.Parse(req.URL) + req.Host = u.Host + req.addPortIfNot() + } + return +} +func (req *HTTPRequest) HTTPS() (err error) { + req.Host = req.hostOrURL + req.addPortIfNot() + //_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n") + return +} +func (req *HTTPRequest) HTTPSReply() (err error) { + _, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n") + return +} +func (req *HTTPRequest) IsHTTPS() bool { + return req.Method == "CONNECT" +} + +func (req *HTTPRequest) BasicAuth() (err error) { + + //log.Printf("request :%s", string(b[:n])) + authorization, err := req.getHeader("Authorization") + if err != nil { + fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized") + closeConn(req.conn) + return + } + //log.Printf("Authorization:%s", authorization) + basic := strings.Fields(authorization) + if len(basic) != 2 { + err = fmt.Errorf("authorization data error,ERR:%s", authorization) + closeConn(req.conn) + return + } + user, err := base64.StdEncoding.DecodeString(basic[1]) + if err != nil { + err = fmt.Errorf("authorization data parse error,ERR:%s", err) + closeConn(req.conn) + return + } + authOk := basicAuth.Check(string(user)) + //log.Printf("auth %s,%v", string(user), authOk) + if !authOk { + fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized") + closeConn(req.conn) + err = fmt.Errorf("basic auth fail") + return + } + return +} +func (req *HTTPRequest) getHTTPURL() (URL string, err error) { + if !strings.HasPrefix(req.hostOrURL, "/") { + return req.hostOrURL, nil + } + _host, err := req.getHeader("host") + if err != nil { + return + } + URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL) + return +} +func (req *HTTPRequest) getHeader(key string) (val string, err error) { + key = strings.ToUpper(key) + lines := strings.Split(string(req.headBuf), "\r\n") + for _, line := range lines { + line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2) + if len(line) == 2 { + k := strings.ToUpper(strings.Trim(line[0], " ")) + v := strings.Trim(line[1], " ") + if key == k { + val = v + return + } + } + } + err = fmt.Errorf("can not find HOST header") + return +} +func (req *HTTPRequest) addPortIfNot() (newHost string) { + //newHost = req.Host + port := "80" + if req.IsHTTPS() { + port = "443" + } + if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) { + //newHost = req.Host + ":" + port + //req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1)) + req.Host = req.Host + ":" + port + } + return +} + +// func (req *HTTPRequest) fixHost(host string) string { +// if !strings.HasPrefix(host, "[") && len(strings.Split(host, ":")) > 2 { +// if strings.HasSuffix(host, ":80") { +// return fmt.Sprintf("[%s]:80", host[:strings.LastIndex(host, ":")]) +// } +// if strings.HasSuffix(host, ":443") { +// return fmt.Sprintf("[%s]:443", host[:strings.LastIndex(host, ":")]) +// } +// } +// return host +// }