// goproxy/functions.go
// Snapshot: 2017-09-20 19:28:36 +08:00 (359 lines, 8.8 KiB, Go, executable file).

package main
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"runtime/debug"
"strconv"
"strings"
"syscall"
"time"
)
// IoBind relays data in both directions between dst and src in the
// background and returns immediately.
//
// Two goroutines run concurrently: one copies src→dst and reports each write
// to cfn with isPositive=false, the other copies dst→src and reports with
// isPositive=true. When bytesPreSec > 0 the reading side of each direction is
// wrapped with NewReader(...).SetRateLimit(bytesPreSec).
//
// fn is called once, from a background goroutine, with the result of
// whichever direction finishes first. The other direction keeps copying until
// its own read or write fails, so closing dst/src (typically inside fn) is
// what stops it.
//
// A panic inside either copy goroutine is logged and turned into an error
// delivered to fn. Without that, fn would wait on the remaining direction,
// which may never end if fn is what closes the connections.
// cfn may be nil.
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
	go func() {
		defer func() {
			if e := recover(); e != nil {
				log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
			}
		}()
		// Capacity 2: each direction sends exactly one result, so the
		// goroutine that finishes second never blocks after fn has
		// consumed the first one.
		errchn := make(chan error, 2)

		// pipe copies from → to (rate limited if requested) and sends
		// exactly one result on errchn, even if the copy panics.
		pipe := func(to io.ReadWriter, from io.ReadWriter, isPositive bool) {
			defer func() {
				if e := recover(); e != nil {
					log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
					errchn <- fmt.Errorf("IoBind copy panic: %v", e)
				}
			}()
			count := func(c int) {
				if cfn != nil {
					cfn(c, isPositive)
				}
			}
			var err error
			if bytesPreSec > 0 {
				limited := NewReader(from)
				limited.SetRateLimit(bytesPreSec)
				_, err = ioCopy(to, limited, count)
			} else {
				_, err = ioCopy(to, from, count)
			}
			errchn <- err
		}

		go pipe(dst, src, false)
		go pipe(src, dst, true)
		fn(<-errchn)
	}()
}
// ioCopy copies from src to dst through a 32 KiB buffer until src returns an
// error or a write fails, and returns the number of bytes written.
//
// Unlike io.Copy, the terminating read error is returned unchanged, so a
// normal end of input is reported as io.EOF rather than nil.
//
// After every write that stores at least one byte, each non-nil callback in fn
// is called with the number of bytes just written (used for traffic
// accounting).
//
// A short write yields io.ErrShortWrite. A writer that reports an impossible
// count (negative or larger than the buffer it was given) yields an
// "invalid write result" error, mirroring io.Copy's guard.
func ioCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, err error) {
	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			// A misbehaving writer must not corrupt the byte count or
			// feed bogus values to the callbacks.
			if nw < 0 || nw > nr {
				nw = 0
				if ew == nil {
					ew = errors.New("invalid write result")
				}
			}
			if nw > 0 {
				written += int64(nw)
				for _, f := range fn {
					if f != nil {
						f(nw)
					}
				}
			}
			if ew != nil {
				return written, ew
			}
			if nw != nr {
				return written, io.ErrShortWrite
			}
		}
		if er != nil {
			return written, er
		}
	}
}
func TlsConnectHost(host string, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
h := strings.Split(host, ":")
port, _ := strconv.Atoi(h[1])
return TlsConnect(h[0], port, timeout, certBytes, keyBytes)
}
// TlsConnect opens a TCP connection to host:port, waiting at most timeout
// milliseconds, and wraps it in a TLS client.
//
// The TLS config comes from getRequestTlsConfig(certBytes, keyBytes). No
// handshake is performed here; crypto/tls runs it on the first Read/Write or
// on an explicit Handshake call.
//
// The address is built with net.JoinHostPort so IPv6 literals get the
// brackets they need. A host that is already bracketed ("[::1]") is accepted
// too.
//
// NOTE(review): dereferencing tls.Client copies a tls.Conn, which go vet flags
// as copying a lock. It is tolerable only because the connection has not been
// used yet; the by-value return type is kept for caller compatibility.
func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
	conf, err := getRequestTlsConfig(certBytes, keyBytes)
	if err != nil {
		return conn, err
	}
	// Strip existing brackets so JoinHostPort does not double them.
	if len(host) >= 2 && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	addr := net.JoinHostPort(host, strconv.Itoa(port))
	rawConn, err := net.DialTimeout("tcp", addr, time.Duration(timeout)*time.Millisecond)
	if err != nil {
		return conn, err
	}
	return *tls.Client(rawConn, conf), nil
}
// getRequestTlsConfig builds the client-side TLS configuration. The PEM pair
// certBytes/keyBytes is presented as the client certificate, and the
// certificate(s) in certBytes are also the only trusted roots for verifying
// the server, whose name must be "proxy".
//
// If the key pair cannot be loaded, it returns a nil config and that error.
// If certBytes cannot be added to the root pool, the config is still returned,
// together with a "failed to parse root certificate" error.
func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err error) {
	clientCert, loadErr := tls.X509KeyPair(certBytes, keyBytes)
	if loadErr != nil {
		return nil, loadErr
	}
	roots := x509.NewCertPool()
	if parsed := roots.AppendCertsFromPEM(certBytes); !parsed {
		err = errors.New("failed to parse root certificate")
	}
	return &tls.Config{
		Certificates:       []tls.Certificate{clientCert},
		RootCAs:            roots,
		ServerName:         "proxy",
		InsecureSkipVerify: false,
	}, err
}
// ConnectHost opens a plain TCP connection to hostAndPort ("host:port"),
// giving up after timeout milliseconds.
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
	limit := time.Millisecond * time.Duration(timeout)
	return net.DialTimeout("tcp", hostAndPort, limit)
}
// ListenTls starts a TLS listener on ip:port for mutual TLS.
//
// The server presents the certBytes/keyBytes pair. Every client must present a
// certificate that verifies against the certificate(s) in certBytes
// (tls.RequireAndVerifyClientCert). An empty ip binds all interfaces, and IPv6
// literals may be given with or without brackets.
//
// On success it returns a pointer to the listener. It returns a nil listener
// and an error if the key pair cannot be loaded, if certBytes cannot be added
// to the client-CA pool, or if the address cannot be bound.
func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listener, err error) {
	cert, err := tls.X509KeyPair(certBytes, keyBytes)
	if err != nil {
		return nil, err
	}
	clientCertPool := x509.NewCertPool()
	if !clientCertPool.AppendCertsFromPEM(certBytes) {
		// Return immediately: listening without the client-CA pool would
		// reject every client, and a later assignment to err would hide
		// this failure.
		return nil, errors.New("failed to parse root certificate")
	}
	config := &tls.Config{
		ClientCAs:    clientCertPool,
		ServerName:   "proxy",
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.RequireAndVerifyClientCert,
	}
	// Strip existing brackets so JoinHostPort does not double them.
	if len(ip) >= 2 && strings.HasPrefix(ip, "[") && strings.HasSuffix(ip, "]") {
		ip = ip[1 : len(ip)-1]
	}
	listener, err := tls.Listen("tcp", net.JoinHostPort(ip, strconv.Itoa(port)), config)
	if err != nil {
		return nil, err
	}
	return &listener, nil
}
// PathExists reports whether _path exists. Only a "does not exist" error from
// os.Stat counts as absent. Any other Stat failure (for example, permission
// denied) is treated as the path being present.
func PathExists(_path string) bool {
	_, statErr := os.Stat(_path)
	return statErr == nil || !os.IsNotExist(statErr)
}
// HTTPGet performs a GET request on URL, with timeout (in milliseconds) as the
// overall client timeout.
//
// Only transport-level failures are returned (bad URL, connect error,
// timeout, ...). The HTTP status code is not inspected. The request uses its
// own Transport; before returning, the response body is closed and the
// transport's idle connections are shut down.
func HTTPGet(URL string, timeout int) (err error) {
	transport := &http.Transport{}
	// Deferred first so it runs last, after the body is closed.
	defer transport.CloseIdleConnections()
	client := &http.Client{
		Transport: transport,
		Timeout:   time.Duration(timeout) * time.Millisecond,
	}
	resp, err := client.Get(URL)
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	return err
}
func initOutPool(isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) {
var err error
outPool, err = NewConnPool(poolConfig{
IsActive: func(conn interface{}) bool { return true },
Release: func(conn interface{}) {
if conn != nil {
conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
conn.(net.Conn).Close()
// log.Println("conn released")
}
},
InitialCap: InitialCap,
MaxCap: MaxCap,
Factory: func() (conn interface{}, err error) {
conn, err = getConn(isTLS, certBytes, keyBytes, address, timeout)
return
},
})
if err != nil {
log.Fatalf("init conn pool fail ,%s", err)
} else {
log.Printf("init conn pool success")
initPoolDeamon(isTLS, certBytes, keyBytes, address, timeout)
}
}
func getConn(isTLS bool, certBytes, keyBytes []byte, address string, timeout int) (conn interface{}, err error) {
if isTLS {
var _conn tls.Conn
_conn, err = TlsConnectHost(address, timeout, certBytes, keyBytes)
if err == nil {
conn = net.Conn(&_conn)
}
} else {
conn, err = ConnectHost(address, timeout)
}
return
}
// initPoolDeamon starts a background goroutine that periodically checks that
// the upstream at address can still be reached.
//
// The interval, in seconds, comes from the "check-proxy-interval" setting; a
// value <= 0 disables the check. Each round sleeps for the interval and then
// dials a fresh probe connection with getConn. If the dial fails, the whole
// outPool is released. If it succeeds, the probe is closed right away. The
// goroutine never exits once started.
func initPoolDeamon(isTLS bool, certBytes, keyBytes []byte, address string, timeout int) {
	go func() {
		interval := cfg.GetInt("check-proxy-interval")
		if interval <= 0 {
			return
		}
		log.Printf("pool deamon started")
		for {
			time.Sleep(time.Duration(interval) * time.Second)
			probe, err := getConn(isTLS, certBytes, keyBytes, address, timeout)
			if err != nil {
				log.Printf("pool deamon err %s , release pool", err)
				outPool.ReleaseAll()
				continue
			}
			c := probe.(net.Conn)
			c.SetDeadline(time.Now().Add(time.Millisecond))
			c.Close()
		}
	}()
}
// IsBasicAuth reports whether basic authentication is configured: either the
// "auth-file" setting is non-empty or at least one "auth" entry is present.
func IsBasicAuth() bool {
	if cfg.GetString("auth-file") != "" {
		return true
	}
	return len(cfg.GetStringSlice("auth")) > 0
}
// InitBasicAuth replaces the package-level basicAuth store with a fresh one
// and loads credentials into it. Credentials come first from the file named by
// the "auth-file" setting (if set), then from the "auth" entries (if any).
// The number of entries added and the running total are logged after each
// source. An error is returned only if reading the auth file fails.
func InitBasicAuth() (err error) {
	basicAuth = NewBasicAuth()
	if authFile := cfg.GetString("auth-file"); authFile != "" {
		added, fileErr := basicAuth.AddFromFile(authFile)
		if fileErr != nil {
			return fmt.Errorf("auth-file ERR:%s", fileErr)
		}
		log.Printf("auth data added from file %d , total:%d", added, basicAuth.Total())
	}
	if entries := cfg.GetStringSlice("auth"); len(entries) > 0 {
		added := basicAuth.Add(entries)
		log.Printf("auth data added %d, total:%d", added, basicAuth.Total())
	}
	return nil
}
// clean blocks the caller until the process receives SIGINT (os.Interrupt),
// SIGHUP, SIGTERM or SIGQUIT. It then releases the outbound connection pool,
// if one was created, and returns; presumably the caller (main) exits
// afterwards.
//
// Only the first signal is intercepted. signal.Stop restores default handling
// immediately afterwards, so a second Ctrl-C still terminates the process if
// shutdown hangs. Earlier code relayed signals through a goroutine that stayed
// alive and blocked forever on its next send, swallowing every later signal.
func clean() {
	// Buffered so a signal delivered before we start receiving is not lost.
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan,
		os.Interrupt,
		syscall.SIGHUP,
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGQUIT)
	<-signalChan
	signal.Stop(signalChan)
	if outPool != nil {
		fmt.Println("\nReceived an interrupt, stopping services...")
		outPool.ReleaseAll()
	}
}
// HTTPProxyDecoder reads an HTTP request from inConn with NewHTTPRequest
// (passing 4096, presumably a read-buffer size; confirm against
// NewHTTPRequest). It also decides whether the request should go through the
// upstream proxy.
//
// useProxy is the first result of checker.IsBlocked(req.Host) when the checker
// has data loaded, and false otherwise. If reading or parsing fails, the error
// is returned with useProxy=false and a nil request.
func HTTPProxyDecoder(inConn *net.Conn) (useProxy bool, request *HTTPRequest, err error) {
	var req HTTPRequest
	if req, err = NewHTTPRequest(inConn, 4096); err != nil {
		return false, nil, err
	}
	if checker.data != nil {
		useProxy, _, _ = checker.IsBlocked(req.Host)
	}
	return useProxy, &req, nil
}
// type sockaddr struct {
// family uint16
// data [14]byte
// }
// const SO_ORIGINAL_DST = 80
// realServerAddress returns an intercepted connection's original destination.
// func realServerAddress(conn *net.Conn) (string, error) {
// tcpConn, ok := (*conn).(*net.TCPConn)
// if !ok {
// return "", errors.New("not a TCPConn")
// }
// file, err := tcpConn.File()
// if err != nil {
// return "", err
// }
// // To avoid potential problems from making the socket non-blocking.
// tcpConn.Close()
// *conn, err = net.FileConn(file)
// if err != nil {
// return "", err
// }
// defer file.Close()
// fd := file.Fd()
// var addr sockaddr
// size := uint32(unsafe.Sizeof(addr))
// err = getsockopt(int(fd), syscall.SOL_IP, SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&addr)), &size)
// if err != nil {
// return "", err
// }
// var ip net.IP
// switch addr.family {
// case syscall.AF_INET:
// ip = addr.data[2:6]
// default:
// return "", errors.New("unrecognized address family")
// }
// port := int(addr.data[0])<<8 + int(addr.data[1])
// return net.JoinHostPort(ip.String(), strconv.Itoa(port)), nil
// }
// func getsockopt(s int, level int, name int, val uintptr, vallen *uint32) (err error) {
// _, _, e1 := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(unsafe.Pointer(vallen)), 0)
// if e1 != 0 {
// err = e1
// }
// return
// }