goproxy/main.go
2017-09-19 17:16:16 +08:00

383 lines
9.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/url"
"os"
"os/exec"
"os/signal"
"runtime/debug"
"strings"
"syscall"
"time"
)
const APP_VERSION = "2.2"
var (
checker Checker
proxyIsTls bool
localIsTls bool
proxyAddr string
isTCP bool
connTimeout int
certBytes []byte
keyBytes []byte
err error
outPool ConnPool
basicAuth BasicAuth
httpAuthorization bool
)
func init() {
err = initConfig()
if err != nil {
log.Printf("err : %s", err)
return
}
}
func main() {
//catch panic error
defer func() {
e := recover()
if e != nil {
log.Printf("err : %s,\ntrace:%s", e, string(debug.Stack()))
}
}()
//define command line args
proxyIsTls = cfg.GetBool("parent-tls")
localIsTls = cfg.GetBool("local-tls")
proxyAddr = cfg.GetString("parent")
bindIP := cfg.GetString("ip")
bindPort := cfg.GetInt("port")
timeout := cfg.GetInt("check-timeout")
connTimeout = cfg.GetInt("tcp-timeout")
interval := cfg.GetInt("check-interval")
certFile := cfg.GetString("cert")
keyFile := cfg.GetString("key")
blockedFile := cfg.GetString("blocked")
directFile := cfg.GetString("direct")
isTCP = cfg.GetBool("tcp")
poolInitSize := cfg.GetInt("pool-size")
//check args required
if proxyIsTls && proxyAddr == "" {
log.Fatalf("parent proxy address required")
}
//check tls cert&key file
if certFile == "" {
certFile = "proxy.crt"
}
if keyFile == "" {
keyFile = "proxy.key"
}
if proxyIsTls || localIsTls {
certBytes, err = ioutil.ReadFile(certFile)
if err != nil {
log.Printf("err : %s", err)
return
}
keyBytes, err = ioutil.ReadFile(keyFile)
if err != nil {
log.Printf("err : %s", err)
return
}
}
//init tls info string
var proxyIsTlsStr string
var localIsTlsStr string
protocolStr := "tcp"
if !isTCP {
protocolStr = "http(s)"
}
if proxyIsTls {
proxyIsTlsStr = "tls "
}
if localIsTls {
localIsTlsStr = "tls "
}
//init checker and pool if needed
if proxyAddr != "" {
if !isTCP && !cfg.GetBool("always") {
checker = NewChecker(timeout, int64(interval), blockedFile, directFile)
}
log.Printf("use %sparent %s proxy : %s", proxyIsTlsStr, protocolStr, proxyAddr)
initOutPool(proxyIsTls, certBytes, keyBytes, proxyAddr, connTimeout, poolInitSize, poolInitSize*2)
} else if isTCP {
log.Printf("tcp proxy need parent")
return
}
//init basic auth only in http mode
if !isTCP {
basicAuth = NewBasicAuth()
if cfg.GetString("auth-file") != "" {
httpAuthorization = true
n, err := basicAuth.AddFromFile(cfg.GetString("auth-file"))
if err != nil {
log.Fatalf("auth-file:%s", err)
}
log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total())
}
if len(cfg.GetStringSlice("auth")) > 0 {
httpAuthorization = true
n := basicAuth.Add(cfg.GetStringSlice("auth"))
log.Printf("auth data added %d, total:%d", n, basicAuth.Total())
}
}
//listen
sc := NewServerChannel(bindIP, bindPort)
var err error
if localIsTls {
err = sc.ListenTls(certBytes, keyBytes, connHandler)
} else {
err = sc.ListenTCP(connHandler)
}
//listen fail
if err != nil {
log.Fatalf("ERR:%s", err)
} else {
log.Printf("%s %sproxy on %s", protocolStr, localIsTlsStr, (*sc.Listener).Addr())
}
//block main()
signalChan := make(chan os.Signal, 1)
cleanupDone := make(chan bool)
signal.Notify(signalChan,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT)
go func() {
for _ = range signalChan {
if outPool != nil {
fmt.Println("\nReceived an interrupt, stopping services...")
outPool.ReleaseAll()
//time.Sleep(time.Second * 10)
// fmt.Println("done")
}
cleanupDone <- true
}
}()
<-cleanupDone
}
func connHandler(inConn net.Conn) {
defer func() {
err := recover()
if err != nil {
log.Printf("connHandler crashed,err:%s\nstack:%s", err, string(debug.Stack()))
closeConn(&inConn)
}
}()
if isTCP {
tcpHandler(&inConn)
} else {
httpHandler(&inConn)
}
}
func tcpHandler(inConn *net.Conn) {
var outConn net.Conn
var _outConn interface{}
_outConn, err = outPool.Get()
if err != nil {
log.Printf("connect to %s , err:%s", proxyAddr, err)
closeConn(inConn)
return
}
outConn = _outConn.(net.Conn)
inAddr := (*inConn).RemoteAddr().String()
outAddr := outConn.RemoteAddr().String()
IoBind((*inConn), outConn, func(err error) {
log.Printf("conn %s - %s released", inAddr, outAddr)
closeConn(inConn)
closeConn(&outConn)
}, func(n int, d bool) {}, 0)
log.Printf("conn %s - %s connected", inAddr, outAddr)
}
func httpHandler(inConn *net.Conn) {
var b [4096]byte
var n int
n, err = (*inConn).Read(b[:])
if err != nil {
if err != io.EOF {
log.Printf("read err:%s", err)
}
closeConn(inConn)
return
}
var method, host, address string
index := bytes.IndexByte(b[:], '\n')
if index == -1 {
log.Printf("data err:%s", string(b[:n])[:50])
closeConn(inConn)
return
}
fmt.Sscanf(string(b[:index]), "%s%s", &method, &host)
if method == "" || host == "" {
log.Printf("data err:%s", string(b[:n])[:50])
closeConn(inConn)
return
}
isHTTPS := method == "CONNECT"
//http basic auth,only http
if !isHTTPS {
if httpAuthorization {
//log.Printf("request :%s", string(b[:n]))
authorization, err := getHeader("Authorization", b[:n])
if err != nil {
fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
closeConn(inConn)
return
}
//log.Printf("Authorization:%s", authorization)
basic := strings.Fields(authorization)
if len(basic) != 2 {
log.Printf("authorization data error,ERR:%s", authorization)
closeConn(inConn)
return
}
user, err := base64.StdEncoding.DecodeString(basic[1])
if err != nil {
log.Printf("authorization data parse error,ERR:%s", err)
closeConn(inConn)
return
}
authOk := basicAuth.Check(string(user))
//log.Printf("auth %s,%v", string(user), authOk)
if !authOk {
fmt.Fprint(*inConn, "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
closeConn(inConn)
return
}
}
}
var bytes []byte
if isHTTPS { //https访问
// [dd:dafds:fsd:dasd:2.2.23.3] or 2.2.23.3 or [dd:dafds:fsd:dasd:2.2.23.3]:2323 or 2.2.23.3:1234
address = fixHost(host)
if hostIsNoPort(host) { //host不带端口 默认443
address = address + ":443"
}
} else { //http访问
hostPortURL, err := url.Parse(host)
if err != nil {
log.Printf("url.Parse %s ERR:%s", host, err)
closeConn(inConn)
return
}
_host := fixHost(hostPortURL.Host)
address = _host
if hostIsNoPort(_host) { //host不带端口 默认80
address = _host + ":80"
}
if _host != hostPortURL.Host {
bytes = []byte(strings.Replace(string(b[:n]), hostPortURL.Host, _host, 1))
host = strings.Replace(host, hostPortURL.Host, _host, 1)
}
}
//get url , reslut host is the full url
host, err = getURL(b[:n], host)
// log.Printf("body:%s", string(b[:n]))
// log.Printf("%s:%s", method, host)
if err != nil {
log.Printf("header data err:%s", err)
closeConn(inConn)
return
}
useProxy := false
if proxyAddr != "" {
if cfg.GetBool("always") {
useProxy = true
} else {
if isHTTPS {
checker.Add(address, true, method, "", nil)
} else {
if bytes != nil {
checker.Add(address, false, method, host, bytes)
} else {
checker.Add(address, false, method, host, b[:n])
}
}
useProxy, _, _ = checker.IsBlocked(address)
}
// var failN, successN uint
// useProxy, failN, successN = checker.IsBlocked(address)
//log.Printf("use proxy ? %s : %v ,fail:%d, success:%d", address, useProxy, failN, successN)
//log.Printf("use proxy ? %s : %v", address, useProxy)
}
var outConn net.Conn
var _outConn interface{}
if useProxy {
_outConn, err = outPool.Get()
if err == nil {
outConn = _outConn.(net.Conn)
}
} else {
outConn, err = ConnectHost(address, connTimeout)
}
if err != nil {
log.Printf("connect to %s , err:%s", address, err)
closeConn(inConn)
return
}
inAddr := (*inConn).RemoteAddr().String()
outAddr := outConn.RemoteAddr().String()
if isHTTPS {
if useProxy {
outConn.Write(b[:n])
} else {
fmt.Fprint(*inConn, "HTTP/1.1 200 Connection established\r\n\r\n")
}
} else {
if bytes != nil {
outConn.Write(bytes)
} else {
outConn.Write(b[:n])
}
}
IoBind(*inConn, outConn, func(err error) {
log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address)
closeConn(inConn)
closeConn(&outConn)
}, func(n int, d bool) {}, 0)
log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address)
}
func closeConn(conn *net.Conn) {
if *conn != nil {
(*conn).SetDeadline(time.Now().Add(time.Millisecond))
(*conn).Close()
}
}
func keygen() (err error) {
cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
cmd = exec.Command("sh", "-c", `openssl req -new -key proxy.key -x509 -days 3650 -out proxy.crt -subj /C=CN/ST=BJ/O="Localhost Ltd"/CN=proxy`)
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
return
}