Signed-off-by: arraykeys@gmail.com <arraykeys@gmail.com>

This commit is contained in:
arraykeys@gmail.com
2017-09-20 19:28:36 +08:00
parent 09242575d8
commit aff38118e5
4 changed files with 750 additions and 335 deletions

View File

@ -23,8 +23,12 @@ func initConfig() (err error) {
pflag.BoolP("parent-tls", "X", false, "parent proxy is tls") pflag.BoolP("parent-tls", "X", false, "parent proxy is tls")
pflag.BoolP("local-tls", "x", false, "local proxy is tls") pflag.BoolP("local-tls", "x", false, "local proxy is tls")
pflag.BoolP("parent-tcp", "W", false, "parent proxy is tcp")
pflag.BoolP("local-tcp", "w", true, "local proxy is tcp")
pflag.BoolP("parent-udp", "U", false, "parent is udp")
pflag.BoolP("local-udp", "u", false, "local proxy is udp")
version := pflag.BoolP("version", "v", false, "show version") version := pflag.BoolP("version", "v", false, "show version")
pflag.BoolP("tcp", "C", false, "proxy on tcp") pflag.BoolP("local-http", "z", false, "proxy on http")
pflag.Bool("always", false, "always use parent proxy") pflag.Bool("always", false, "always use parent proxy")
pflag.Int("check-proxy-interval", 3, "check if proxy is okay every interval seconds") pflag.Int("check-proxy-interval", 3, "check if proxy is okay every interval seconds")
@ -47,7 +51,11 @@ func initConfig() (err error) {
cfg.BindPFlag("parent-tls", pflag.Lookup("parent-tls")) cfg.BindPFlag("parent-tls", pflag.Lookup("parent-tls"))
cfg.BindPFlag("local-tls", pflag.Lookup("local-tls")) cfg.BindPFlag("local-tls", pflag.Lookup("local-tls"))
cfg.BindPFlag("tcp", pflag.Lookup("tcp")) cfg.BindPFlag("parent-udp", pflag.Lookup("parent-udp"))
cfg.BindPFlag("local-udp", pflag.Lookup("local-udp"))
cfg.BindPFlag("parent-tcp", pflag.Lookup("parent-tcp"))
cfg.BindPFlag("local-tcp", pflag.Lookup("local-tcp"))
cfg.BindPFlag("local-http", pflag.Lookup("local-http"))
cfg.BindPFlag("always", pflag.Lookup("always")) cfg.BindPFlag("always", pflag.Lookup("always"))
cfg.BindPFlag("check-proxy-interval", pflag.Lookup("check-proxy-interval")) cfg.BindPFlag("check-proxy-interval", pflag.Lookup("check-proxy-interval"))
cfg.BindPFlag("port", pflag.Lookup("port")) cfg.BindPFlag("port", pflag.Lookup("port"))

View File

@ -10,9 +10,11 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"os/signal"
"runtime/debug" "runtime/debug"
"strconv" "strconv"
"strings" "strings"
"syscall"
"time" "time"
) )
@ -186,6 +188,7 @@ func HTTPGet(URL string, timeout int) (err error) {
} }
func initOutPool(isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) { func initOutPool(isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) {
var err error
outPool, err = NewConnPool(poolConfig{ outPool, err = NewConnPool(poolConfig{
IsActive: func(conn interface{}) bool { return true }, IsActive: func(conn interface{}) bool { return true },
Release: func(conn interface{}) { Release: func(conn interface{}) {
@ -241,57 +244,66 @@ func initPoolDeamon(isTLS bool, certBytes, keyBytes []byte, address string, time
} }
}() }()
} }
func getURL(header []byte, host string) (URL string, err error) { func IsBasicAuth() bool {
if !strings.HasPrefix(host, "/") { return cfg.GetString("auth-file") != "" || len(cfg.GetStringSlice("auth")) > 0
return host, nil
} }
_host, err := getHeader("host", header) func InitBasicAuth() (err error) {
basicAuth = NewBasicAuth()
if cfg.GetString("auth-file") != "" {
n, err := basicAuth.AddFromFile(cfg.GetString("auth-file"))
if err != nil { if err != nil {
return fmt.Errorf("auth-file ERR:%s", err)
}
log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total())
}
if len(cfg.GetStringSlice("auth")) > 0 {
n := basicAuth.Add(cfg.GetStringSlice("auth"))
log.Printf("auth data added %d, total:%d", n, basicAuth.Total())
}
return return
} }
URL = fmt.Sprintf("http://%s%s", _host, host) func clean() {
//block main()
signalChan := make(chan os.Signal, 1)
cleanupDone := make(chan bool)
signal.Notify(signalChan,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT)
go func() {
for _ = range signalChan {
if outPool != nil {
fmt.Println("\nReceived an interrupt, stopping services...")
outPool.ReleaseAll()
//time.Sleep(time.Second * 10)
// fmt.Println("done")
}
cleanupDone <- true
}
}()
<-cleanupDone
}
func HTTPProxyDecoder(inConn *net.Conn) (useProxy bool, request *HTTPRequest, err error) {
var req HTTPRequest
req, err = NewHTTPRequest(inConn, 4096)
if err != nil {
//log.Printf("NewHTTPRequest ERR:%s", err)
return return
} }
func getHeader(key string, headData []byte) (val string, err error) { useProxy = false
key = strings.ToUpper(key) if checker.data != nil {
lines := strings.Split(string(headData), "\r\n") useProxy, _, _ = checker.IsBlocked(req.Host)
for _, line := range lines { }
line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2) request = &req
if len(line) == 2 {
k := strings.ToUpper(strings.Trim(line[0], " "))
v := strings.Trim(line[1], " ")
if key == k {
val = v
return return
} }
}
}
err = fmt.Errorf("can not find HOST header")
return
}
func hostIsNoPort(host string) bool {
//host: [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234
if strings.HasPrefix(host, "[") {
return strings.HasSuffix(host, "]")
}
return strings.Index(host, ":") == -1
}
func fixHost(host string) string {
if !strings.HasPrefix(host, "[") && len(strings.Split(host, ":")) > 2 {
if strings.HasSuffix(host, ":80") {
return fmt.Sprintf("[%s]:80", host[:strings.LastIndex(host, ":")])
}
if strings.HasSuffix(host, ":443") {
return fmt.Sprintf("[%s]:443", host[:strings.LastIndex(host, ":")])
}
}
return host
}
type sockaddr struct { // type sockaddr struct {
family uint16 // family uint16
data [14]byte // data [14]byte
} // }
// const SO_ORIGINAL_DST = 80 // const SO_ORIGINAL_DST = 80

798
main.go
View File

@ -1,20 +1,13 @@
package main package main
import ( import (
"bytes"
"encoding/base64"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"net/url"
"os"
"os/exec" "os/exec"
"os/signal"
"runtime/debug" "runtime/debug"
"strings"
"syscall"
"time" "time"
) )
@ -22,25 +15,166 @@ const APP_VERSION = "2.2"
var ( var (
checker Checker checker Checker
proxyIsTls bool
localIsTls bool
proxyAddr string
isTCP bool
connTimeout int
certBytes []byte certBytes []byte
keyBytes []byte keyBytes []byte
err error
outPool ConnPool outPool ConnPool
basicAuth BasicAuth basicAuth BasicAuth
httpAuthorization bool
) )
func init() { func init() {
err = initConfig() err := initConfig()
if err != nil {
log.Fatalf("err : %s", err)
}
//Init
err = Init()
if err != nil {
log.Fatalf("err : %s", err)
}
isLocalHTTP := cfg.GetBool("local-http")
isTLS := cfg.GetBool("local-tls") || cfg.GetBool("parent-tls")
isTCP := isLocalHTTP || isTLS || cfg.GetBool("local-tcp") || cfg.GetBool("parent-tcp")
//InitTCP
if isTCP {
err = InitTCP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitTLS
if isTLS {
err = InitTLS()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitUDP
if cfg.GetBool("local-udp") || cfg.GetBool("parent-udp") {
err = InitUDP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocal
err = InitLocal()
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalTCP
if cfg.GetBool("local-tcp") || cfg.GetBool("local-tls") || isLocalHTTP {
err = InitLocalTCP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalTLS
if cfg.GetBool("local-tls") {
err = InitLocalTLS()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalHTTP
if isLocalHTTP {
err = InitLocalHTTP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalUDP
if cfg.GetBool("local-udp") {
err = InitLocalUDP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParent
if cfg.GetString("parent") != "" {
err = InitParent()
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParentTCP
if cfg.GetBool("parent-tcp") || cfg.GetBool("parent-tls") {
err = InitParentTCP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParentTLS
if cfg.GetBool("parent-tls") {
err = InitParentTLS()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParentUDP
if cfg.GetBool("parent-udp") {
err = InitParentUDP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
}
}
func Init() (err error) {
return
}
func InitTCP() (err error) {
return
}
func InitTLS() (err error) {
certBytes, err = ioutil.ReadFile(cfg.GetString("cert"))
if err != nil { if err != nil {
log.Printf("err : %s", err) log.Printf("err : %s", err)
return return
} }
keyBytes, err = ioutil.ReadFile(cfg.GetString("key"))
if err != nil {
log.Printf("err : %s", err)
return
}
return
}
func InitUDP() (err error) {
return
}
func InitLocal() (err error) {
return
}
func InitLocalTCP() (err error) {
return
}
func InitLocalTLS() (err error) {
return
}
func InitLocalHTTP() (err error) {
err = InitBasicAuth()
if err != nil {
return
}
return
}
func InitLocalUDP() (err error) {
return
}
func InitParent() (err error) {
initOutPool(cfg.GetBool("parent-tls"), certBytes, keyBytes, cfg.GetString("parent"), cfg.GetInt("tcp-timeout"), cfg.GetInt("pool-size"), cfg.GetInt("pool-size")*2)
checker = NewChecker(cfg.GetInt("check-timeout"), int64(cfg.GetInt("check-interval")), cfg.GetString("blocked"), cfg.GetString("direct"))
log.Printf("use parent proxy : %s, udp : %v, tcp : %v, tls: %v", cfg.GetString("parent"), cfg.GetBool("parent-udp"), cfg.GetBool("parent-tcp"), cfg.GetBool("parent-tls"))
return
}
func InitParentTCP() (err error) {
return
}
func InitParentTLS() (err error) {
return
}
func InitParentUDP() (err error) {
return
} }
func main() { func main() {
//catch panic error //catch panic error
@ -50,284 +184,84 @@ func main() {
log.Printf("err : %s,\ntrace:%s", e, string(debug.Stack())) log.Printf("err : %s,\ntrace:%s", e, string(debug.Stack()))
} }
}() }()
//define command line args
proxyIsTls = cfg.GetBool("parent-tls")
localIsTls = cfg.GetBool("local-tls")
proxyAddr = cfg.GetString("parent")
bindIP := cfg.GetString("ip")
bindPort := cfg.GetInt("port")
timeout := cfg.GetInt("check-timeout")
connTimeout = cfg.GetInt("tcp-timeout")
interval := cfg.GetInt("check-interval")
certFile := cfg.GetString("cert")
keyFile := cfg.GetString("key")
blockedFile := cfg.GetString("blocked")
directFile := cfg.GetString("direct")
isTCP = cfg.GetBool("tcp") sc := NewServerChannel(cfg.GetString("ip"), cfg.GetInt("port"))
poolInitSize := cfg.GetInt("pool-size") if cfg.GetBool("local-tls") {
LocalTLSServer(&sc)
} else if cfg.GetBool("local-tcp") {
LocalTCPServer(&sc)
} else if cfg.GetBool("local-udp") {
LocalUDPServer(&sc)
}
log.Printf("proxy on %s , udp: %v, tcp: %v, tls: %v ,http: %v", (*sc.Listener).Addr(), cfg.GetBool("local-udp"), cfg.GetBool("local-tcp"), cfg.GetBool("local-tls"), cfg.GetBool("local-http"))
//check args required clean()
if proxyIsTls && proxyAddr == "" {
log.Fatalf("parent proxy address required")
} }
//check tls cert&key file func CheckTCPDeocder(inConn *net.Conn) (useProxy bool, address string, req *HTTPRequest, err error) {
if certFile == "" { if cfg.GetBool("local-http") {
certFile = "proxy.crt" useProxy, req, err = HTTPProxyDecoder(inConn)
}
if keyFile == "" {
keyFile = "proxy.key"
}
if proxyIsTls || localIsTls {
certBytes, err = ioutil.ReadFile(certFile)
if err != nil {
log.Printf("err : %s", err)
return
}
keyBytes, err = ioutil.ReadFile(keyFile)
if err != nil {
log.Printf("err : %s", err)
return
}
}
//init tls info string
var proxyIsTlsStr string
var localIsTlsStr string
protocolStr := "tcp"
if !isTCP {
protocolStr = "http(s)"
}
if proxyIsTls {
proxyIsTlsStr = "tls "
}
if localIsTls {
localIsTlsStr = "tls "
}
//init checker and pool if needed
if proxyAddr != "" {
if !isTCP && !cfg.GetBool("always") {
checker = NewChecker(timeout, int64(interval), blockedFile, directFile)
}
log.Printf("use %sparent %s proxy : %s", proxyIsTlsStr, protocolStr, proxyAddr)
initOutPool(proxyIsTls, certBytes, keyBytes, proxyAddr, connTimeout, poolInitSize, poolInitSize*2)
} else if isTCP {
log.Printf("tcp proxy need parent")
return
}
//init basic auth only in http mode
if !isTCP {
basicAuth = NewBasicAuth()
if cfg.GetString("auth-file") != "" {
httpAuthorization = true
n, err := basicAuth.AddFromFile(cfg.GetString("auth-file"))
if err != nil {
log.Fatalf("auth-file:%s", err)
}
log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total())
}
if len(cfg.GetStringSlice("auth")) > 0 {
httpAuthorization = true
n := basicAuth.Add(cfg.GetStringSlice("auth"))
log.Printf("auth data added %d, total:%d", n, basicAuth.Total())
}
}
//listen
sc := NewServerChannel(bindIP, bindPort)
var err error
if localIsTls {
err = sc.ListenTls(certBytes, keyBytes, connHandler)
} else {
err = sc.ListenTCP(connHandler)
}
//listen fail
if err != nil {
log.Fatalf("ERR:%s", err)
} else {
log.Printf("%s %sproxy on %s", protocolStr, localIsTlsStr, (*sc.Listener).Addr())
}
//block main()
signalChan := make(chan os.Signal, 1)
cleanupDone := make(chan bool)
signal.Notify(signalChan,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT)
go func() {
for _ = range signalChan {
if outPool != nil {
fmt.Println("\nReceived an interrupt, stopping services...")
outPool.ReleaseAll()
//time.Sleep(time.Second * 10)
// fmt.Println("done")
}
cleanupDone <- true
}
}()
<-cleanupDone
}
func connHandler(inConn net.Conn) {
defer func() {
err := recover()
if err != nil {
log.Printf("connHandler crashed,err:%s\nstack:%s", err, string(debug.Stack()))
closeConn(&inConn)
}
}()
if isTCP {
tcpHandler(&inConn)
} else {
httpHandler(&inConn)
}
}
func tcpHandler(inConn *net.Conn) {
var outConn net.Conn
var _outConn interface{}
_outConn, err = outPool.Get()
if err != nil {
log.Printf("connect to %s , err:%s", proxyAddr, err)
closeConn(inConn)
return
}
outConn = _outConn.(net.Conn)
inAddr := (*inConn).RemoteAddr().String()
outAddr := outConn.RemoteAddr().String()
IoBind((*inConn), outConn, func(err error) {
log.Printf("conn %s - %s released", inAddr, outAddr)
closeConn(inConn)
closeConn(&outConn)
}, func(n int, d bool) {}, 0)
log.Printf("conn %s - %s connected", inAddr, outAddr)
}
func httpHandler(inConn *net.Conn) {
var b [4096]byte
var n int
n, err = (*inConn).Read(b[:])
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Printf("read err:%s", err) log.Printf("http proxy decode error , ERR:%s", err)
} }
closeConn(inConn)
return return
} }
var method, host, address string address = req.Host
index := bytes.IndexByte(b[:], '\n')
if index == -1 {
log.Printf("data err:%s", string(b[:n])[:50])
closeConn(inConn)
return
}
fmt.Sscanf(string(b[:index]), "%s%s", &method, &host)
if method == "" || host == "" {
log.Printf("data err:%s", string(b[:n])[:50])
closeConn(inConn)
return
}
isHTTPS := method == "CONNECT"
//http basic auth,only http
if !isHTTPS {
if httpAuthorization {
//log.Printf("request :%s", string(b[:n]))
authorization, err := getHeader("Authorization", b[:n])
if err != nil {
fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
closeConn(inConn)
return
}
//log.Printf("Authorization:%s", authorization)
basic := strings.Fields(authorization)
if len(basic) != 2 {
log.Printf("authorization data error,ERR:%s", authorization)
closeConn(inConn)
return
}
user, err := base64.StdEncoding.DecodeString(basic[1])
if err != nil {
log.Printf("authorization data parse error,ERR:%s", err)
closeConn(inConn)
return
}
authOk := basicAuth.Check(string(user))
//log.Printf("auth %s,%v", string(user), authOk)
if !authOk {
fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
closeConn(inConn)
return
}
}
}
var bytes []byte
if isHTTPS { //https访问
// [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234
address = fixHost(host)
if hostIsNoPort(host) { //host不带端口 默认443
address = address + ":443"
}
} else { //http访问
hostPortURL, err := url.Parse(host)
if err != nil {
log.Printf("url.Parse %s ERR:%s", host, err)
closeConn(inConn)
return
}
_host := fixHost(hostPortURL.Host)
address = _host
if hostIsNoPort(_host) { //host不带端口 默认80
address = _host + ":80"
}
if _host != hostPortURL.Host {
bytes = []byte(strings.Replace(string(b[:n]), hostPortURL.Host, _host, 1))
host = strings.Replace(host, hostPortURL.Host, _host, 1)
}
}
//get url , reslut host is the full url
host, err = getURL(b[:n], host)
// log.Printf("body:%s", string(b[:n]))
// log.Printf("%s:%s", method, host)
if err != nil {
log.Printf("header data err:%s", err)
closeConn(inConn)
return
}
useProxy := false
if proxyAddr != "" {
if cfg.GetBool("always") {
useProxy = true
} else { } else {
if isHTTPS { address = cfg.GetString("parent")
checker.Add(address, true, method, "", nil)
} else {
if bytes != nil {
checker.Add(address, false, method, host, bytes)
} else {
checker.Add(address, false, method, host, b[:n])
}
}
useProxy, _, _ = checker.IsBlocked(address)
}
// var failN, successN uint
// useProxy, failN, successN = checker.IsBlocked(address)
//log.Printf("use proxy ? %s : %v ,fail:%d, success:%d", address, useProxy, failN, successN)
//log.Printf("use proxy ? %s : %v", address, useProxy)
} }
return
}
func LocalTCPServer(sc *ServerChannel) {
(*sc).ListenTCP(func(inConn net.Conn) {
userProxy, address, req, err := CheckTCPDeocder(&inConn)
if err != nil {
log.Printf("%s", err)
return
}
TCPOutBridge(&inConn, userProxy, address, req)
})
}
func LocalTLSServer(sc *ServerChannel) {
certBytes, err := ioutil.ReadFile(cfg.GetString("cert"))
if err != nil {
log.Fatalf("err : %s", err)
return
}
keyBytes, err := ioutil.ReadFile(cfg.GetString("key"))
if err != nil {
log.Fatalf("err : %s", err)
return
}
(*sc).ListenTls(certBytes, keyBytes, func(inConn net.Conn) {
userProxy, address, req, err := CheckTCPDeocder(&inConn)
if err != nil {
return
}
TCPOutBridge(&inConn, userProxy, address, req)
})
}
func LocalUDPServer(sc *ServerChannel) {
(*sc).ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) {
})
}
func TCPOutBridge(inConn *net.Conn, userProxy bool, address string, req *HTTPRequest) {
var outConn net.Conn var outConn net.Conn
var _outConn interface{} var _outConn interface{}
if useProxy { var err error
if userProxy {
_outConn, err = outPool.Get() _outConn, err = outPool.Get()
if err == nil { if err == nil {
outConn = _outConn.(net.Conn) outConn = _outConn.(net.Conn)
} }
} else { } else {
outConn, err = ConnectHost(address, connTimeout) outConn, err = ConnectHost(address, cfg.GetInt("tcp-timeout"))
} }
if err != nil { if err != nil {
log.Printf("connect to %s , err:%s", address, err) log.Printf("connect to %s , err:%s", address, err)
@ -336,20 +270,16 @@ func httpHandler(inConn *net.Conn) {
} }
inAddr := (*inConn).RemoteAddr().String() inAddr := (*inConn).RemoteAddr().String()
outAddr := outConn.RemoteAddr().String() outAddr := outConn.RemoteAddr().String()
//log.Printf("%s use proxy %v",address, userProxy)
if isHTTPS { if req != nil {
if useProxy { if req.IsHTTPS() && !userProxy {
outConn.Write(b[:n]) req.HTTPSReply()
} else { } else {
fmt.Fprint(*inConn, "HTTP/1.1 200 Connection established\r\n\r\n") outConn.Write(req.headBuf)
}
} else {
if bytes != nil {
outConn.Write(bytes)
} else {
outConn.Write(b[:n])
} }
} }
IoBind(*inConn, outConn, func(err error) { IoBind(*inConn, outConn, func(err error) {
log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address) log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address)
closeConn(inConn) closeConn(inConn)
@ -357,6 +287,306 @@ func httpHandler(inConn *net.Conn) {
}, func(n int, d bool) {}, 0) }, func(n int, d bool) {}, 0)
log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address) log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address)
} }
func UDPOutBridge() {
}
// func DoUDP() {
// if cfg.GetBool("local-udp") {
// } else {
// }
// }
// //DoTCP contains tcp && http
// func DoTCP() {
// //define command line args
// proxyIsTls = cfg.GetBool("parent-tls")
// localIsTls = cfg.GetBool("local-tls")
// proxyAddr = cfg.GetString("parent")
// bindIP := cfg.GetString("ip")
// bindPort := cfg.GetInt("port")
// timeout := cfg.GetInt("check-timeout")
// connTimeout = cfg.GetInt("tcp-timeout")
// interval := cfg.GetInt("check-interval")
// certFile := cfg.GetString("cert")
// keyFile := cfg.GetString("key")
// blockedFile := cfg.GetString("blocked")
// directFile := cfg.GetString("direct")
// isTCP = cfg.GetBool("tcp")
// poolInitSize := cfg.GetInt("pool-size")
// //check args required
// if proxyIsTls && proxyAddr == "" {
// log.Fatalf("parent proxy address required")
// }
// //check tls cert&key file
// if certFile == "" {
// certFile = "proxy.crt"
// }
// if keyFile == "" {
// keyFile = "proxy.key"
// }
// if proxyIsTls || localIsTls {
// certBytes, err = ioutil.ReadFile(certFile)
// if err != nil {
// log.Printf("err : %s", err)
// return
// }
// keyBytes, err = ioutil.ReadFile(keyFile)
// if err != nil {
// log.Printf("err : %s", err)
// return
// }
// }
// //init tls info string
// var proxyIsTlsStr string
// var localIsTlsStr string
// protocolStr := "tcp"
// if !isTCP {
// protocolStr = "http(s)"
// }
// if proxyIsTls {
// proxyIsTlsStr = "tls "
// }
// if localIsTls {
// localIsTlsStr = "tls "
// }
// //init checker and pool if needed
// if proxyAddr != "" {
// if !isTCP && !cfg.GetBool("always") {
// checker = NewChecker(timeout, int64(interval), blockedFile, directFile)
// }
// log.Printf("use %sparent %s proxy : %s", proxyIsTlsStr, protocolStr, proxyAddr)
// initOutPool(proxyIsTls, certBytes, keyBytes, proxyAddr, connTimeout, poolInitSize, poolInitSize*2)
// } else if isTCP {
// log.Printf("tcp proxy need parent")
// return
// }
// //init basic auth only in http mode
// if !isTCP {
// basicAuth = NewBasicAuth()
// if cfg.GetString("auth-file") != "" {
// httpAuthorization = true
// n, err := basicAuth.AddFromFile(cfg.GetString("auth-file"))
// if err != nil {
// log.Fatalf("auth-file:%s", err)
// }
// log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total())
// }
// if len(cfg.GetStringSlice("auth")) > 0 {
// httpAuthorization = true
// n := basicAuth.Add(cfg.GetStringSlice("auth"))
// log.Printf("auth data added %d, total:%d", n, basicAuth.Total())
// }
// }
// //listen
// sc := NewServerChannel(bindIP, bindPort)
// var err error
// if localIsTls {
// err = sc.ListenTls(certBytes, keyBytes, connHandler)
// } else {
// err = sc.ListenTCP(connHandler)
// }
// //listen fail
// if err != nil {
// log.Fatalf("ERR:%s", err)
// } else {
// log.Printf("%s %sproxy on %s", protocolStr, localIsTlsStr, (*sc.Listener).Addr())
// }
// }
// func connHandler(inConn net.Conn) {
// defer func() {
// err := recover()
// if err != nil {
// log.Printf("connHandler crashed,err:%s\nstack:%s", err, string(debug.Stack()))
// closeConn(&inConn)
// }
// }()
// if isTCP {
// tcpHandler(&inConn)
// } else {
// httpHandler(&inConn)
// }
// }
// func tcpHandler(inConn *net.Conn) {
// var outConn net.Conn
// var _outConn interface{}
// _outConn, err = outPool.Get()
// if err != nil {
// log.Printf("connect to %s , err:%s", proxyAddr, err)
// closeConn(inConn)
// return
// }
// outConn = _outConn.(net.Conn)
// inAddr := (*inConn).RemoteAddr().String()
// outAddr := outConn.RemoteAddr().String()
// IoBind((*inConn), outConn, func(err error) {
// log.Printf("conn %s - %s released", inAddr, outAddr)
// closeConn(inConn)
// closeConn(&outConn)
// }, func(n int, d bool) {}, 0)
// log.Printf("conn %s - %s connected", inAddr, outAddr)
// }
// func httpHandler(inConn *net.Conn) {
// var b [4096]byte
// var n int
// n, err = (*inConn).Read(b[:])
// if err != nil {
// if err != io.EOF {
// log.Printf("read err:%s", err)
// }
// closeConn(inConn)
// return
// }
// var method, host, address string
// index := bytes.IndexByte(b[:], '\n')
// if index == -1 {
// log.Printf("data err:%s", string(b[:n])[:50])
// closeConn(inConn)
// return
// }
// fmt.Sscanf(string(b[:index]), "%s%s", &method, &host)
// if method == "" || host == "" {
// log.Printf("data err:%s", string(b[:n])[:50])
// closeConn(inConn)
// return
// }
// isHTTPS := method == "CONNECT"
// //http basic auth,only http
// if !isHTTPS {
// if httpAuthorization {
// //log.Printf("request :%s", string(b[:n]))
// authorization, err := getHeader("Authorization", b[:n])
// if err != nil {
// fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
// closeConn(inConn)
// return
// }
// //log.Printf("Authorization:%s", authorization)
// basic := strings.Fields(authorization)
// if len(basic) != 2 {
// log.Printf("authorization data error,ERR:%s", authorization)
// closeConn(inConn)
// return
// }
// user, err := base64.StdEncoding.DecodeString(basic[1])
// if err != nil {
// log.Printf("authorization data parse error,ERR:%s", err)
// closeConn(inConn)
// return
// }
// authOk := basicAuth.Check(string(user))
// //log.Printf("auth %s,%v", string(user), authOk)
// if !authOk {
// fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
// closeConn(inConn)
// return
// }
// }
// }
// var bytes []byte
// if isHTTPS { //https访问
// // [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234
// address = fixHost(host)
// if hostIsNoPort(host) { //host不带端口 默认443
// address = address + ":443"
// }
// } else { //http访问
// hostPortURL, err := url.Parse(host)
// if err != nil {
// log.Printf("url.Parse %s ERR:%s", host, err)
// closeConn(inConn)
// return
// }
// _host := fixHost(hostPortURL.Host)
// address = _host
// if hostIsNoPort(_host) { //host不带端口 默认80
// address = _host + ":80"
// }
// if _host != hostPortURL.Host {
// bytes = []byte(strings.Replace(string(b[:n]), hostPortURL.Host, _host, 1))
// host = strings.Replace(host, hostPortURL.Host, _host, 1)
// }
// }
// //get url , reslut host is the full url
// host, err = getURL(b[:n], host)
// // log.Printf("body:%s", string(b[:n]))
// // log.Printf("%s:%s", method, host)
// if err != nil {
// log.Printf("header data err:%s", err)
// closeConn(inConn)
// return
// }
// useProxy := false
// if proxyAddr != "" {
// if cfg.GetBool("always") {
// useProxy = true
// } else {
// if isHTTPS {
// checker.Add(address, true, method, "", nil)
// } else {
// if bytes != nil {
// checker.Add(address, false, method, host, bytes)
// } else {
// checker.Add(address, false, method, host, b[:n])
// }
// }
// useProxy, _, _ = checker.IsBlocked(address)
// }
// // var failN, successN uint
// // useProxy, failN, successN = checker.IsBlocked(address)
// //log.Printf("use proxy ? %s : %v ,fail:%d, success:%d", address, useProxy, failN, successN)
// //log.Printf("use proxy ? %s : %v", address, useProxy)
// }
// var outConn net.Conn
// var _outConn interface{}
// if useProxy {
// _outConn, err = outPool.Get()
// if err == nil {
// outConn = _outConn.(net.Conn)
// }
// } else {
// outConn, err = ConnectHost(address, connTimeout)
// }
// if err != nil {
// log.Printf("connect to %s , err:%s", address, err)
// closeConn(inConn)
// return
// }
// inAddr := (*inConn).RemoteAddr().String()
// outAddr := outConn.RemoteAddr().String()
// if isHTTPS {
// if useProxy {
// outConn.Write(b[:n])
// } else {
// fmt.Fprint(*inConn, "HTTP/1.1 200 Connection established\r\n\r\n")
// }
// } else {
// if bytes != nil {
// outConn.Write(bytes)
// } else {
// outConn.Write(b[:n])
// }
// }
// IoBind(*inConn, outConn, func(err error) {
// log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address)
// closeConn(inConn)
// closeConn(&outConn)
// }, func(n int, d bool) {}, 0)
// log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address)
// }
func closeConn(conn *net.Conn) { func closeConn(conn *net.Conn) {
if *conn != nil { if *conn != nil {
(*conn).SetDeadline(time.Now().Add(time.Millisecond)) (*conn).SetDeadline(time.Now().Add(time.Millisecond))

View File

@ -1,6 +1,10 @@
package main package main
import ( import (
"bytes"
"encoding/base64"
"fmt"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
@ -217,3 +221,164 @@ func (ba *BasicAuth) Total() (n int) {
n = ba.data.Count() n = ba.data.Count()
return return
} }
type HTTPRequest struct {
headBuf []byte
conn *net.Conn
Host string
Method string
URL string
hostOrURL string
}
func NewHTTPRequest(inConn *net.Conn, bufSize int) (req HTTPRequest, err error) {
buf := make([]byte, bufSize)
len := 0
req = HTTPRequest{
conn: inConn,
}
len, err = (*inConn).Read(buf[:])
if err != nil {
if err != io.EOF {
err = fmt.Errorf("http decoder read err:%s", err)
}
closeConn(inConn)
return
}
req.headBuf = buf[:len]
index := bytes.IndexByte(req.headBuf, '\n')
if index == -1 {
err = fmt.Errorf("http decoder data line err:%s", string(req.headBuf)[:50])
closeConn(inConn)
return
}
fmt.Sscanf(string(req.headBuf[:index]), "%s%s", &req.Method, &req.hostOrURL)
if req.Method == "" || req.hostOrURL == "" {
err = fmt.Errorf("http decoder data err:%s", string(req.headBuf)[:50])
closeConn(inConn)
return
}
req.Method = strings.ToUpper(req.Method)
log.Printf("%s:%s", req.Method, req.hostOrURL)
if req.IsHTTPS() {
err = req.HTTPS()
} else {
err = req.HTTP()
}
return
}
func (req *HTTPRequest) HTTP() (err error) {
if IsBasicAuth() {
err = req.BasicAuth()
if err != nil {
return
}
}
req.URL, err = req.getHTTPURL()
if err == nil {
u, _ := url.Parse(req.URL)
req.Host = u.Host
req.addPortIfNot()
}
return
}
func (req *HTTPRequest) HTTPS() (err error) {
req.Host = req.hostOrURL
req.addPortIfNot()
//_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
return
}
func (req *HTTPRequest) HTTPSReply() (err error) {
_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
return
}
func (req *HTTPRequest) IsHTTPS() bool {
return req.Method == "CONNECT"
}
func (req *HTTPRequest) BasicAuth() (err error) {
//log.Printf("request :%s", string(b[:n]))
authorization, err := req.getHeader("Authorization")
if err != nil {
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
closeConn(req.conn)
return
}
//log.Printf("Authorization:%s", authorization)
basic := strings.Fields(authorization)
if len(basic) != 2 {
err = fmt.Errorf("authorization data error,ERR:%s", authorization)
closeConn(req.conn)
return
}
user, err := base64.StdEncoding.DecodeString(basic[1])
if err != nil {
err = fmt.Errorf("authorization data parse error,ERR:%s", err)
closeConn(req.conn)
return
}
authOk := basicAuth.Check(string(user))
//log.Printf("auth %s,%v", string(user), authOk)
if !authOk {
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
closeConn(req.conn)
err = fmt.Errorf("basic auth fail")
return
}
return
}
func (req *HTTPRequest) getHTTPURL() (URL string, err error) {
if !strings.HasPrefix(req.hostOrURL, "/") {
return req.hostOrURL, nil
}
_host, err := req.getHeader("host")
if err != nil {
return
}
URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL)
return
}
func (req *HTTPRequest) getHeader(key string) (val string, err error) {
key = strings.ToUpper(key)
lines := strings.Split(string(req.headBuf), "\r\n")
for _, line := range lines {
line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2)
if len(line) == 2 {
k := strings.ToUpper(strings.Trim(line[0], " "))
v := strings.Trim(line[1], " ")
if key == k {
val = v
return
}
}
}
err = fmt.Errorf("can not find HOST header")
return
}
func (req *HTTPRequest) addPortIfNot() (newHost string) {
//newHost = req.Host
port := "80"
if req.IsHTTPS() {
port = "443"
}
if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) {
//newHost = req.Host + ":" + port
//req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1))
req.Host = req.Host + ":" + port
}
return
}
// func (req *HTTPRequest) fixHost(host string) string {
// if !strings.HasPrefix(host, "[") && len(strings.Split(host, ":")) > 2 {
// if strings.HasSuffix(host, ":80") {
// return fmt.Sprintf("[%s]:80", host[:strings.LastIndex(host, ":")])
// }
// if strings.HasSuffix(host, ":443") {
// return fmt.Sprintf("[%s]:443", host[:strings.LastIndex(host, ":")])
// }
// }
// return host
// }