goproxy/main.go
2017-09-20 19:35:36 +08:00

615 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os/exec"
"runtime/debug"
"time"
)
const APP_VERSION = "2.2"
var (
checker Checker
certBytes []byte
keyBytes []byte
outPool ConnPool
basicAuth BasicAuth
)
func init() {
err := initConfig()
if err != nil {
log.Fatalf("err : %s", err)
}
//Init
err = Init()
if err != nil {
log.Fatalf("err : %s", err)
}
isLocalHTTP := cfg.GetBool("local-http")
isTLS := cfg.GetBool("local-tls") || cfg.GetBool("parent-tls")
isTCP := isLocalHTTP || isTLS || cfg.GetBool("local-tcp") || cfg.GetBool("parent-tcp")
//InitTCP
if isTCP {
err = InitTCP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitTLS
if isTLS {
err = InitTLS()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitUDP
if cfg.GetBool("local-udp") || cfg.GetBool("parent-udp") {
err = InitUDP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocal
err = InitLocal()
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalTCP
if cfg.GetBool("local-tcp") || cfg.GetBool("local-tls") || isLocalHTTP {
err = InitLocalTCP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalTLS
if cfg.GetBool("local-tls") {
err = InitLocalTLS()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalHTTP
if isLocalHTTP {
err = InitLocalHTTP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitLocalUDP
if cfg.GetBool("local-udp") {
err = InitLocalUDP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParent
if cfg.GetString("parent") != "" {
err = InitParent()
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParentTCP
if cfg.GetBool("parent-tcp") || cfg.GetBool("parent-tls") {
err = InitParentTCP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParentTLS
if cfg.GetBool("parent-tls") {
err = InitParentTLS()
}
if err != nil {
log.Fatalf("err : %s", err)
}
//InitParentUDP
if cfg.GetBool("parent-udp") {
err = InitParentUDP()
}
if err != nil {
log.Fatalf("err : %s", err)
}
}
}
func Init() (err error) {
return
}
func InitTCP() (err error) {
return
}
func InitTLS() (err error) {
certBytes, err = ioutil.ReadFile(cfg.GetString("cert"))
if err != nil {
log.Printf("err : %s", err)
return
}
keyBytes, err = ioutil.ReadFile(cfg.GetString("key"))
if err != nil {
log.Printf("err : %s", err)
return
}
return
}
func InitUDP() (err error) {
return
}
func InitLocal() (err error) {
return
}
func InitLocalTCP() (err error) {
return
}
func InitLocalTLS() (err error) {
return
}
func InitLocalHTTP() (err error) {
err = InitBasicAuth()
if err != nil {
return
}
return
}
func InitLocalUDP() (err error) {
return
}
func InitParent() (err error) {
initOutPool(cfg.GetBool("parent-tls"), certBytes, keyBytes, cfg.GetString("parent"), cfg.GetInt("tcp-timeout"), cfg.GetInt("pool-size"), cfg.GetInt("pool-size")*2)
checker = NewChecker(cfg.GetInt("check-timeout"), int64(cfg.GetInt("check-interval")), cfg.GetString("blocked"), cfg.GetString("direct"))
log.Printf("use parent proxy : %s, udp : %v, tcp : %v, tls: %v", cfg.GetString("parent"), cfg.GetBool("parent-udp"), cfg.GetBool("parent-tcp"), cfg.GetBool("parent-tls"))
return
}
func InitParentTCP() (err error) {
return
}
func InitParentTLS() (err error) {
return
}
func InitParentUDP() (err error) {
return
}
func main() {
//catch panic error
defer func() {
e := recover()
if e != nil {
log.Printf("err : %s,\ntrace:%s", e, string(debug.Stack()))
}
}()
sc := NewServerChannel(cfg.GetString("ip"), cfg.GetInt("port"))
if cfg.GetBool("local-tls") {
LocalTLSServer(&sc)
} else if cfg.GetBool("local-tcp") {
LocalTCPServer(&sc)
} else if cfg.GetBool("local-udp") {
LocalUDPServer(&sc)
}
log.Printf("proxy on %s , udp: %v, tcp: %v, tls: %v ,http: %v", (*sc.Listener).Addr(), cfg.GetBool("local-udp"), cfg.GetBool("local-tcp"), cfg.GetBool("local-tls"), cfg.GetBool("local-http"))
clean()
}
func CheckTCPDeocder(inConn *net.Conn) (useProxy bool, address string, req *HTTPRequest, err error) {
if cfg.GetBool("local-http") {
useProxy, req, err = HTTPProxyDecoder(inConn)
if err != nil {
if err != io.EOF {
log.Printf("http proxy decode error , ERR:%s", err)
}
return
}
address = req.Host
} else {
address = cfg.GetString("parent")
}
if cfg.GetBool("always") {
useProxy = true
}
return
}
func LocalTCPServer(sc *ServerChannel) {
(*sc).ListenTCP(func(inConn net.Conn) {
userProxy, address, req, err := CheckTCPDeocder(&inConn)
if err != nil {
log.Printf("%s", err)
return
}
TCPOutBridge(&inConn, userProxy, address, req)
})
}
func LocalTLSServer(sc *ServerChannel) {
certBytes, err := ioutil.ReadFile(cfg.GetString("cert"))
if err != nil {
log.Fatalf("err : %s", err)
return
}
keyBytes, err := ioutil.ReadFile(cfg.GetString("key"))
if err != nil {
log.Fatalf("err : %s", err)
return
}
(*sc).ListenTls(certBytes, keyBytes, func(inConn net.Conn) {
userProxy, address, req, err := CheckTCPDeocder(&inConn)
if err != nil {
return
}
TCPOutBridge(&inConn, userProxy, address, req)
})
}
func LocalUDPServer(sc *ServerChannel) {
(*sc).ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) {
})
}
func TCPOutBridge(inConn *net.Conn, userProxy bool, address string, req *HTTPRequest) {
var outConn net.Conn
var _outConn interface{}
var err error
if userProxy {
_outConn, err = outPool.Get()
if err == nil {
outConn = _outConn.(net.Conn)
}
} else {
outConn, err = ConnectHost(address, cfg.GetInt("tcp-timeout"))
}
if err != nil {
log.Printf("connect to %s , err:%s", address, err)
closeConn(inConn)
return
}
inAddr := (*inConn).RemoteAddr().String()
outAddr := outConn.RemoteAddr().String()
log.Printf("%s use proxy %v", address, userProxy)
if req != nil {
if req.IsHTTPS() && !userProxy {
req.HTTPSReply()
} else {
outConn.Write(req.headBuf)
}
}
IoBind(*inConn, outConn, func(err error) {
log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address)
closeConn(inConn)
closeConn(&outConn)
}, func(n int, d bool) {}, 0)
log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address)
}
func UDPOutBridge() {
}
// func DoUDP() {
// if cfg.GetBool("local-udp") {
// } else {
// }
// }
// //DoTCP contains tcp && http
// func DoTCP() {
// //define command line args
// proxyIsTls = cfg.GetBool("parent-tls")
// localIsTls = cfg.GetBool("local-tls")
// proxyAddr = cfg.GetString("parent")
// bindIP := cfg.GetString("ip")
// bindPort := cfg.GetInt("port")
// timeout := cfg.GetInt("check-timeout")
// connTimeout = cfg.GetInt("tcp-timeout")
// interval := cfg.GetInt("check-interval")
// certFile := cfg.GetString("cert")
// keyFile := cfg.GetString("key")
// blockedFile := cfg.GetString("blocked")
// directFile := cfg.GetString("direct")
// isTCP = cfg.GetBool("tcp")
// poolInitSize := cfg.GetInt("pool-size")
// //check args required
// if proxyIsTls && proxyAddr == "" {
// log.Fatalf("parent proxy address required")
// }
// //check tls cert&key file
// if certFile == "" {
// certFile = "proxy.crt"
// }
// if keyFile == "" {
// keyFile = "proxy.key"
// }
// if proxyIsTls || localIsTls {
// certBytes, err = ioutil.ReadFile(certFile)
// if err != nil {
// log.Printf("err : %s", err)
// return
// }
// keyBytes, err = ioutil.ReadFile(keyFile)
// if err != nil {
// log.Printf("err : %s", err)
// return
// }
// }
// //init tls info string
// var proxyIsTlsStr string
// var localIsTlsStr string
// protocolStr := "tcp"
// if !isTCP {
// protocolStr = "http(s)"
// }
// if proxyIsTls {
// proxyIsTlsStr = "tls "
// }
// if localIsTls {
// localIsTlsStr = "tls "
// }
// //init checker and pool if needed
// if proxyAddr != "" {
// if !isTCP && !cfg.GetBool("always") {
// checker = NewChecker(timeout, int64(interval), blockedFile, directFile)
// }
// log.Printf("use %sparent %s proxy : %s", proxyIsTlsStr, protocolStr, proxyAddr)
// initOutPool(proxyIsTls, certBytes, keyBytes, proxyAddr, connTimeout, poolInitSize, poolInitSize*2)
// } else if isTCP {
// log.Printf("tcp proxy need parent")
// return
// }
// //init basic auth only in http mode
// if !isTCP {
// basicAuth = NewBasicAuth()
// if cfg.GetString("auth-file") != "" {
// httpAuthorization = true
// n, err := basicAuth.AddFromFile(cfg.GetString("auth-file"))
// if err != nil {
// log.Fatalf("auth-file:%s", err)
// }
// log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total())
// }
// if len(cfg.GetStringSlice("auth")) > 0 {
// httpAuthorization = true
// n := basicAuth.Add(cfg.GetStringSlice("auth"))
// log.Printf("auth data added %d, total:%d", n, basicAuth.Total())
// }
// }
// //listen
// sc := NewServerChannel(bindIP, bindPort)
// var err error
// if localIsTls {
// err = sc.ListenTls(certBytes, keyBytes, connHandler)
// } else {
// err = sc.ListenTCP(connHandler)
// }
// //listen fail
// if err != nil {
// log.Fatalf("ERR:%s", err)
// } else {
// log.Printf("%s %sproxy on %s", protocolStr, localIsTlsStr, (*sc.Listener).Addr())
// }
// }
// func connHandler(inConn net.Conn) {
// defer func() {
// err := recover()
// if err != nil {
// log.Printf("connHandler crashed,err:%s\nstack:%s", err, string(debug.Stack()))
// closeConn(&inConn)
// }
// }()
// if isTCP {
// tcpHandler(&inConn)
// } else {
// httpHandler(&inConn)
// }
// }
// func tcpHandler(inConn *net.Conn) {
// var outConn net.Conn
// var _outConn interface{}
// _outConn, err = outPool.Get()
// if err != nil {
// log.Printf("connect to %s , err:%s", proxyAddr, err)
// closeConn(inConn)
// return
// }
// outConn = _outConn.(net.Conn)
// inAddr := (*inConn).RemoteAddr().String()
// outAddr := outConn.RemoteAddr().String()
// IoBind((*inConn), outConn, func(err error) {
// log.Printf("conn %s - %s released", inAddr, outAddr)
// closeConn(inConn)
// closeConn(&outConn)
// }, func(n int, d bool) {}, 0)
// log.Printf("conn %s - %s connected", inAddr, outAddr)
// }
// func httpHandler(inConn *net.Conn) {
// var b [4096]byte
// var n int
// n, err = (*inConn).Read(b[:])
// if err != nil {
// if err != io.EOF {
// log.Printf("read err:%s", err)
// }
// closeConn(inConn)
// return
// }
// var method, host, address string
// index := bytes.IndexByte(b[:], '\n')
// if index == -1 {
// log.Printf("data err:%s", string(b[:n])[:50])
// closeConn(inConn)
// return
// }
// fmt.Sscanf(string(b[:index]), "%s%s", &method, &host)
// if method == "" || host == "" {
// log.Printf("data err:%s", string(b[:n])[:50])
// closeConn(inConn)
// return
// }
// isHTTPS := method == "CONNECT"
// //http basic auth,only http
// if !isHTTPS {
// if httpAuthorization {
// //log.Printf("request :%s", string(b[:n]))
// authorization, err := getHeader("Authorization", b[:n])
// if err != nil {
// fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
// closeConn(inConn)
// return
// }
// //log.Printf("Authorization:%s", authorization)
// basic := strings.Fields(authorization)
// if len(basic) != 2 {
// log.Printf("authorization data error,ERR:%s", authorization)
// closeConn(inConn)
// return
// }
// user, err := base64.StdEncoding.DecodeString(basic[1])
// if err != nil {
// log.Printf("authorization data parse error,ERR:%s", err)
// closeConn(inConn)
// return
// }
// authOk := basicAuth.Check(string(user))
// //log.Printf("auth %s,%v", string(user), authOk)
// if !authOk {
// fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
// closeConn(inConn)
// return
// }
// }
// }
// var bytes []byte
// if isHTTPS { //https访问
// // [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234
// address = fixHost(host)
// if hostIsNoPort(host) { //host不带端口 默认443
// address = address + ":443"
// }
// } else { //http访问
// hostPortURL, err := url.Parse(host)
// if err != nil {
// log.Printf("url.Parse %s ERR:%s", host, err)
// closeConn(inConn)
// return
// }
// _host := fixHost(hostPortURL.Host)
// address = _host
// if hostIsNoPort(_host) { //host不带端口 默认80
// address = _host + ":80"
// }
// if _host != hostPortURL.Host {
// bytes = []byte(strings.Replace(string(b[:n]), hostPortURL.Host, _host, 1))
// host = strings.Replace(host, hostPortURL.Host, _host, 1)
// }
// }
// //get url , reslut host is the full url
// host, err = getURL(b[:n], host)
// // log.Printf("body:%s", string(b[:n]))
// // log.Printf("%s:%s", method, host)
// if err != nil {
// log.Printf("header data err:%s", err)
// closeConn(inConn)
// return
// }
// useProxy := false
// if proxyAddr != "" {
// if cfg.GetBool("always") {
// useProxy = true
// } else {
// if isHTTPS {
// checker.Add(address, true, method, "", nil)
// } else {
// if bytes != nil {
// checker.Add(address, false, method, host, bytes)
// } else {
// checker.Add(address, false, method, host, b[:n])
// }
// }
// useProxy, _, _ = checker.IsBlocked(address)
// }
// // var failN, successN uint
// // useProxy, failN, successN = checker.IsBlocked(address)
// //log.Printf("use proxy ? %s : %v ,fail:%d, success:%d", address, useProxy, failN, successN)
// //log.Printf("use proxy ? %s : %v", address, useProxy)
// }
// var outConn net.Conn
// var _outConn interface{}
// if useProxy {
// _outConn, err = outPool.Get()
// if err == nil {
// outConn = _outConn.(net.Conn)
// }
// } else {
// outConn, err = ConnectHost(address, connTimeout)
// }
// if err != nil {
// log.Printf("connect to %s , err:%s", address, err)
// closeConn(inConn)
// return
// }
// inAddr := (*inConn).RemoteAddr().String()
// outAddr := outConn.RemoteAddr().String()
// if isHTTPS {
// if useProxy {
// outConn.Write(b[:n])
// } else {
// fmt.Fprint(*inConn, "HTTP/1.1 200 Connection established\r\n\r\n")
// }
// } else {
// if bytes != nil {
// outConn.Write(bytes)
// } else {
// outConn.Write(b[:n])
// }
// }
// IoBind(*inConn, outConn, func(err error) {
// log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address)
// closeConn(inConn)
// closeConn(&outConn)
// }, func(n int, d bool) {}, 0)
// log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address)
// }
func closeConn(conn *net.Conn) {
if *conn != nil {
(*conn).SetDeadline(time.Now().Add(time.Millisecond))
(*conn).Close()
}
}
func keygen() (err error) {
cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
cmd = exec.Command("sh", "-c", `openssl req -new -key proxy.key -x509 -days 3650 -out proxy.crt -subj /C=CN/ST=BJ/O="Localhost Ltd"/CN=proxy`)
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
return
}