Signed-off-by: arraykeys@gmail.com <arraykeys@gmail.com>

This commit is contained in:
arraykeys@gmail.com
2018-03-12 17:31:35 +08:00
parent f87cbf73e8
commit 34e9e362b9
16 changed files with 194 additions and 108 deletions

View File

@ -73,6 +73,7 @@ func initConfig() (err error) {
//########http######### //########http#########
http := app.Command("http", "proxy on http mode") http := app.Command("http", "proxy on http mode")
httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
httpArgs.CaCertFile = http.Flag("ca", "ca cert file for tls").Default("").String()
httpArgs.CertFile = http.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() httpArgs.CertFile = http.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
httpArgs.KeyFile = http.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() httpArgs.KeyFile = http.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
httpArgs.LocalType = http.Flag("local-type", "local protocol type <tls|tcp|kcp>").Default("tcp").Short('t').Enum("tls", "tcp", "kcp") httpArgs.LocalType = http.Flag("local-type", "local protocol type <tls|tcp|kcp>").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")
@ -189,6 +190,7 @@ func initConfig() (err error) {
socksArgs.UDPParent = socks.Flag("udp-parent", "udp parent address, such as: \"23.32.32.19:33090\"").Default("").Short('X').String() socksArgs.UDPParent = socks.Flag("udp-parent", "udp parent address, such as: \"23.32.32.19:33090\"").Default("").Short('X').String()
socksArgs.UDPLocal = socks.Flag("udp-local", "udp local ip:port to listen").Short('x').Default(":33090").String() socksArgs.UDPLocal = socks.Flag("udp-local", "udp local ip:port to listen").Short('x').Default(":33090").String()
socksArgs.CertFile = socks.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() socksArgs.CertFile = socks.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
socksArgs.CaCertFile = socks.Flag("ca", "ca cert file for tls").Default("").String()
socksArgs.KeyFile = socks.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() socksArgs.KeyFile = socks.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
socksArgs.SSHUser = socks.Flag("ssh-user", "user for ssh").Short('u').Default("").String() socksArgs.SSHUser = socks.Flag("ssh-user", "user for ssh").Short('u').Default("").String()
socksArgs.SSHKeyFile = socks.Flag("ssh-key", "private key file for ssh").Short('S').Default("").String() socksArgs.SSHKeyFile = socks.Flag("ssh-key", "private key file for ssh").Short('S').Default("").String()
@ -213,6 +215,7 @@ func initConfig() (err error) {
spsArgs.Parent = sps.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() spsArgs.Parent = sps.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
spsArgs.CertFile = sps.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() spsArgs.CertFile = sps.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
spsArgs.KeyFile = sps.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() spsArgs.KeyFile = sps.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
spsArgs.CaCertFile = sps.Flag("ca", "ca cert file for tls").Default("").String()
spsArgs.Timeout = sps.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Short('i').Default("2000").Int() spsArgs.Timeout = sps.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Short('i').Default("2000").Int()
spsArgs.ParentType = sps.Flag("parent-type", "parent protocol type <tls|tcp|kcp>").Short('T').Enum("tls", "tcp", "kcp") spsArgs.ParentType = sps.Flag("parent-type", "parent protocol type <tls|tcp|kcp>").Short('T').Enum("tls", "tcp", "kcp")
spsArgs.LocalType = sps.Flag("local-type", "local protocol type <tls|tcp|kcp>").Default("tcp").Short('t').Enum("tls", "tcp", "kcp") spsArgs.LocalType = sps.Flag("local-type", "local protocol type <tls|tcp|kcp>").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")

View File

@ -120,6 +120,8 @@ type HTTPArgs struct {
Parent *string Parent *string
CertFile *string CertFile *string
KeyFile *string KeyFile *string
CaCertFile *string
CaCertBytes []byte
CertBytes []byte CertBytes []byte
KeyBytes []byte KeyBytes []byte
Local *string Local *string
@ -169,6 +171,8 @@ type SocksArgs struct {
LocalType *string LocalType *string
CertFile *string CertFile *string
KeyFile *string KeyFile *string
CaCertFile *string
CaCertBytes []byte
CertBytes []byte CertBytes []byte
KeyBytes []byte KeyBytes []byte
SSHKeyFile *string SSHKeyFile *string
@ -199,6 +203,8 @@ type SPSArgs struct {
Parent *string Parent *string
CertFile *string CertFile *string
KeyFile *string KeyFile *string
CaCertFile *string
CaCertBytes []byte
CertBytes []byte CertBytes []byte
KeyBytes []byte KeyBytes []byte
Local *string Local *string

View File

@ -41,6 +41,12 @@ func (s *HTTP) CheckArgs() {
} }
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" { if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
if *s.cfg.CaCertFile != "" {
s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile)
if err != nil {
log.Fatalf("read ca file error,ERR:%s", err)
}
}
} }
if *s.cfg.ParentType == "ssh" { if *s.cfg.ParentType == "ssh" {
if *s.cfg.SSHUser == "" { if *s.cfg.SSHUser == "" {
@ -128,7 +134,7 @@ func (s *HTTP) Start(args interface{}) (err error) {
if *s.cfg.LocalType == TYPE_TCP { if *s.cfg.LocalType == TYPE_TCP {
err = sc.ListenTCP(s.callback) err = sc.ListenTCP(s.callback)
} else if *s.cfg.LocalType == TYPE_TLS { } else if *s.cfg.LocalType == TYPE_TLS {
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes, s.callback)
} else if *s.cfg.LocalType == TYPE_KCP { } else if *s.cfg.LocalType == TYPE_KCP {
err = sc.ListenKCP(s.cfg.KCP, s.callback) err = sc.ListenKCP(s.cfg.KCP, s.callback)
} }
@ -321,7 +327,7 @@ func (s *HTTP) InitOutConnPool() {
*s.cfg.CheckParentInterval, *s.cfg.CheckParentInterval,
*s.cfg.ParentType, *s.cfg.ParentType,
s.cfg.KCP, s.cfg.KCP,
s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes,
s.Resolve(*s.cfg.Parent), s.Resolve(*s.cfg.Parent),
*s.cfg.Timeout, *s.cfg.Timeout,
*s.cfg.PoolSize, *s.cfg.PoolSize,

View File

@ -56,7 +56,7 @@ func (s *MuxBridge) Start(args interface{}) (err error) {
if *s.cfg.LocalType == TYPE_TCP { if *s.cfg.LocalType == TYPE_TCP {
err = sc.ListenTCP(s.handler) err = sc.ListenTCP(s.handler)
} else if *s.cfg.LocalType == TYPE_TLS { } else if *s.cfg.LocalType == TYPE_TLS {
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.handler) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.handler)
} else if *s.cfg.LocalType == TYPE_KCP { } else if *s.cfg.LocalType == TYPE_KCP {
err = sc.ListenKCP(s.cfg.KCP, s.handler) err = sc.ListenKCP(s.cfg.KCP, s.handler)
} }

View File

@ -126,7 +126,7 @@ func (s *MuxClient) Clean() {
func (s *MuxClient) getParentConn() (conn net.Conn, err error) { func (s *MuxClient) getParentConn() (conn net.Conn, err error) {
if *s.cfg.ParentType == "tls" { if *s.cfg.ParentType == "tls" {
var _conn tls.Conn var _conn tls.Conn
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
if err == nil { if err == nil {
conn = net.Conn(&_conn) conn = net.Conn(&_conn)
} }

View File

@ -290,7 +290,7 @@ func (s *MuxServer) GetConn(index string) (conn net.Conn, err error) {
func (s *MuxServer) getParentConn() (conn net.Conn, err error) { func (s *MuxServer) getParentConn() (conn net.Conn, err error) {
if *s.cfg.ParentType == "tls" { if *s.cfg.ParentType == "tls" {
var _conn tls.Conn var _conn tls.Conn
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
if err == nil { if err == nil {
conn = net.Conn(&_conn) conn = net.Conn(&_conn)
} }

View File

@ -37,16 +37,28 @@ func NewSocks() Service {
func (s *Socks) CheckArgs() { func (s *Socks) CheckArgs() {
var err error var err error
if *s.cfg.LocalType == "tls" { if *s.cfg.LocalType == "tls" || (*s.cfg.Parent != "" && *s.cfg.ParentType == "tls") {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
if *s.cfg.CaCertFile != "" {
s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile)
if err != nil {
log.Fatalf("read ca file error,ERR:%s", err)
}
}
} }
if *s.cfg.Parent != "" { if *s.cfg.Parent != "" {
if *s.cfg.ParentType == "" { if *s.cfg.ParentType == "" {
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>") log.Fatalf("parent type unkown,use -T <tls|tcp|ssh|kcp>")
} }
if *s.cfg.ParentType == "tls" { // if *s.cfg.ParentType == "tls" {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) // s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
} // if *s.cfg.CaCertFile != "" {
// s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile)
// if err != nil {
// log.Fatalf("read ca file error,ERR:%s", err)
// }
// }
// }
if *s.cfg.ParentType == "ssh" { if *s.cfg.ParentType == "ssh" {
if *s.cfg.SSHUser == "" { if *s.cfg.SSHUser == "" {
log.Fatalf("ssh user required") log.Fatalf("ssh user required")
@ -138,7 +150,7 @@ func (s *Socks) Start(args interface{}) (err error) {
if *s.cfg.LocalType == TYPE_TCP { if *s.cfg.LocalType == TYPE_TCP {
err = sc.ListenTCP(s.socksConnCallback) err = sc.ListenTCP(s.socksConnCallback)
} else if *s.cfg.LocalType == TYPE_TLS { } else if *s.cfg.LocalType == TYPE_TLS {
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.socksConnCallback) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.socksConnCallback)
} else if *s.cfg.LocalType == TYPE_KCP { } else if *s.cfg.LocalType == TYPE_KCP {
err = sc.ListenKCP(s.cfg.KCP, s.socksConnCallback) err = sc.ListenKCP(s.cfg.KCP, s.socksConnCallback)
} }
@ -471,7 +483,7 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n
case "tcp": case "tcp":
if *s.cfg.ParentType == "tls" { if *s.cfg.ParentType == "tls" {
var _outConn tls.Conn var _outConn tls.Conn
_outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) _outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
outConn = net.Conn(&_outConn) outConn = net.Conn(&_outConn)
} else if *s.cfg.ParentType == "kcp" { } else if *s.cfg.ParentType == "kcp" {
outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), s.cfg.KCP) outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), s.cfg.KCP)

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net" "net"
"runtime/debug" "runtime/debug"
@ -34,6 +35,13 @@ func (s *SPS) CheckArgs() {
} }
if *s.cfg.ParentType == TYPE_TLS || *s.cfg.LocalType == TYPE_TLS { if *s.cfg.ParentType == TYPE_TLS || *s.cfg.LocalType == TYPE_TLS {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
if *s.cfg.CaCertFile != "" {
var err error
s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile)
if err != nil {
log.Fatalf("read ca file error,ERR:%s", err)
}
}
} }
} }
func (s *SPS) InitService() { func (s *SPS) InitService() {
@ -47,7 +55,7 @@ func (s *SPS) InitOutConnPool() {
0, 0,
*s.cfg.ParentType, *s.cfg.ParentType,
s.cfg.KCP, s.cfg.KCP,
s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CertBytes, s.cfg.KeyBytes, nil,
*s.cfg.Parent, *s.cfg.Parent,
*s.cfg.Timeout, *s.cfg.Timeout,
0, 0,
@ -75,7 +83,7 @@ func (s *SPS) Start(args interface{}) (err error) {
if *s.cfg.LocalType == TYPE_TCP { if *s.cfg.LocalType == TYPE_TCP {
err = sc.ListenTCP(s.callback) err = sc.ListenTCP(s.callback)
} else if *s.cfg.LocalType == TYPE_TLS { } else if *s.cfg.LocalType == TYPE_TLS {
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback)
} else if *s.cfg.LocalType == TYPE_KCP { } else if *s.cfg.LocalType == TYPE_KCP {
err = sc.ListenKCP(s.cfg.KCP, s.callback) err = sc.ListenKCP(s.cfg.KCP, s.callback)
} }

View File

@ -56,7 +56,7 @@ func (s *TCP) Start(args interface{}) (err error) {
if *s.cfg.LocalType == TYPE_TCP { if *s.cfg.LocalType == TYPE_TCP {
err = sc.ListenTCP(s.callback) err = sc.ListenTCP(s.callback)
} else if *s.cfg.LocalType == TYPE_TLS { } else if *s.cfg.LocalType == TYPE_TLS {
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback)
} else if *s.cfg.LocalType == TYPE_KCP { } else if *s.cfg.LocalType == TYPE_KCP {
err = sc.ListenKCP(s.cfg.KCP, s.callback) err = sc.ListenKCP(s.cfg.KCP, s.callback)
} }
@ -172,7 +172,7 @@ func (s *TCP) InitOutConnPool() {
*s.cfg.CheckParentInterval, *s.cfg.CheckParentInterval,
*s.cfg.ParentType, *s.cfg.ParentType,
s.cfg.KCP, s.cfg.KCP,
s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CertBytes, s.cfg.KeyBytes, nil,
*s.cfg.Parent, *s.cfg.Parent,
*s.cfg.Timeout, *s.cfg.Timeout,
*s.cfg.PoolSize, *s.cfg.PoolSize,

View File

@ -51,7 +51,7 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
p, _ := strconv.Atoi(port) p, _ := strconv.Atoi(port)
sc := utils.NewServerChannel(host, p) sc := utils.NewServerChannel(host, p)
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) { err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, func(inConn net.Conn) {
//log.Printf("connection from %s ", inConn.RemoteAddr()) //log.Printf("connection from %s ", inConn.RemoteAddr())
reader := bufio.NewReader(inConn) reader := bufio.NewReader(inConn)

View File

@ -161,7 +161,7 @@ func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, e
} }
func (s *TunnelClient) GetConn() (conn net.Conn, err error) { func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
var _conn tls.Conn var _conn tls.Conn
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
if err == nil { if err == nil {
conn = net.Conn(&_conn) conn = net.Conn(&_conn)
} }

View File

@ -174,7 +174,7 @@ func (s *TunnelServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string
} }
func (s *TunnelServerManager) GetConn() (conn net.Conn, err error) { func (s *TunnelServerManager) GetConn() (conn net.Conn, err error) {
var _conn tls.Conn var _conn tls.Conn
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
if err == nil { if err == nil {
conn = net.Conn(&_conn) conn = net.Conn(&_conn)
} }
@ -280,7 +280,7 @@ func (s *TunnelServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err e
} }
func (s *TunnelServer) GetConn() (conn net.Conn, err error) { func (s *TunnelServer) GetConn() (conn net.Conn, err error) {
var _conn tls.Conn var _conn tls.Conn
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
if err == nil { if err == nil {
conn = net.Conn(&_conn) conn = net.Conn(&_conn)
} }

View File

@ -210,7 +210,7 @@ func (s *UDP) InitOutConnPool() {
*s.cfg.CheckParentInterval, *s.cfg.CheckParentInterval,
*s.cfg.ParentType, *s.cfg.ParentType,
kcpcfg.KCPConfigArgs{}, kcpcfg.KCPConfigArgs{},
s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CertBytes, s.cfg.KeyBytes, nil,
*s.cfg.Parent, *s.cfg.Parent,
*s.cfg.Timeout, *s.cfg.Timeout,
*s.cfg.PoolSize, *s.cfg.PoolSize,

View File

@ -86,14 +86,14 @@ func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) {
} }
} }
} }
func TlsConnectHost(host string, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) { func TlsConnectHost(host string, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
h := strings.Split(host, ":") h := strings.Split(host, ":")
port, _ := strconv.Atoi(h[1]) port, _ := strconv.Atoi(h[1])
return TlsConnect(h[0], port, timeout, certBytes, keyBytes) return TlsConnect(h[0], port, timeout, certBytes, keyBytes, caCertBytes)
} }
func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) { func TlsConnect(host string, port, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
conf, err := getRequestTlsConfig(certBytes, keyBytes) conf, err := getRequestTlsConfig(certBytes, keyBytes, caCertBytes)
if err != nil { if err != nil {
return return
} }
@ -103,8 +103,24 @@ func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (con
} }
return *tls.Client(_conn, conf), err return *tls.Client(_conn, conf), err
} }
func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err error) { func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) {
block, _ := pem.Decode(certBytes)
var cert tls.Certificate
cert, err = tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return
}
serverCertPool := x509.NewCertPool()
caBytes := certBytes
if caCertBytes != nil {
caBytes = caCertBytes
}
ok := serverCertPool.AppendCertsFromPEM(caBytes)
if !ok {
err = errors.New("failed to parse root certificate")
}
block, _ := pem.Decode(caBytes)
if block == nil { if block == nil {
panic("failed to parse certificate PEM") panic("failed to parse certificate PEM")
} }
@ -112,34 +128,24 @@ func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err erro
if x509Cert == nil { if x509Cert == nil {
panic("failed to parse block") panic("failed to parse block")
} }
var cert tls.Certificate
cert, err = tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return
}
serverCertPool := x509.NewCertPool()
ok := serverCertPool.AppendCertsFromPEM(certBytes)
if !ok {
err = errors.New("failed to parse root certificate")
}
conf = &tls.Config{ conf = &tls.Config{
RootCAs: serverCertPool, RootCAs: serverCertPool,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
InsecureSkipVerify: false, InsecureSkipVerify: true,
ServerName: x509Cert.Subject.CommonName, ServerName: x509Cert.Subject.CommonName,
// VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
// Roots: serverCertPool, Roots: serverCertPool,
// } }
// for _, rawCert := range rawCerts { for _, rawCert := range rawCerts {
// cert, _ := x509.ParseCertificate(rawCert) cert, _ := x509.ParseCertificate(rawCert)
// _, err := cert.Verify(opts) _, err := cert.Verify(opts)
// if err != nil { if err != nil {
// return err return err
// } }
// } }
// return nil return nil
// }, },
} }
return return
} }
@ -165,22 +171,19 @@ func ConnectKCPHost(hostAndPort string, config kcpcfg.KCPConfigArgs) (conn net.C
return NewCompStream(kcpconn), err return NewCompStream(kcpconn), err
} }
func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listener, err error) { func ListenTls(ip string, port int, certBytes, keyBytes, caCertBytes []byte) (ln *net.Listener, err error) {
block, _ := pem.Decode(certBytes)
if block == nil {
panic("failed to parse certificate PEM")
}
x509Cert, _ := x509.ParseCertificate(block.Bytes)
if x509Cert == nil {
panic("failed to parse block")
}
var cert tls.Certificate var cert tls.Certificate
cert, err = tls.X509KeyPair(certBytes, keyBytes) cert, err = tls.X509KeyPair(certBytes, keyBytes)
if err != nil { if err != nil {
return return
} }
clientCertPool := x509.NewCertPool() clientCertPool := x509.NewCertPool()
ok := clientCertPool.AppendCertsFromPEM(certBytes) caBytes := certBytes
if caCertBytes != nil {
caBytes = caCertBytes
}
ok := clientCertPool.AppendCertsFromPEM(caBytes)
if !ok { if !ok {
err = errors.New("failed to parse root certificate") err = errors.New("failed to parse root certificate")
} }
@ -188,21 +191,6 @@ func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listene
ClientCAs: clientCertPool, ClientCAs: clientCertPool,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert, ClientAuth: tls.RequireAndVerifyClientCert,
ServerName: x509Cert.Subject.CommonName,
// VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// opts := x509.VerifyOptions{
// Roots: clientCertPool,
// }
// for _, rawCert := range rawCerts {
// cert, _ := x509.ParseCertificate(rawCert)
// _, err := cert.Verify(opts)
// fmt.Println("SERVER ERR:", err)
// if err != nil {
// return err
// }
// }
// return nil
// },
} }
_ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config) _ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config)
if err == nil { if err == nil {
@ -245,27 +233,88 @@ func CloseConn(conn *net.Conn) {
} }
} }
func Keygen() (err error) { func Keygen() (err error) {
cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
CList := []string{"AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BL", "BM", "BN", "BO", "BR", "BS", "BW", "BY", "BZ", "CA", "CF", "CG", "CH", "CK", "CL", "CM", "CN", "CO", "CR", "CS", "CU", "CY", "CZ", "DE", "DJ", "DK", "DO", "DZ", "EC", "EE", "EG", "ES", "ET", "FI", "FJ", "FR", "GA", "GB", "GD", "GE", "GF", "GH", "GI", "GM", "GN", "GR", "GT", "GU", "GY", "HK", "HN", "HT", "HU", "ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG", "KH", "KP", "KR", "KT", "KW", "KZ", "LA", "LB", "LC", "LI", "LK", "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "MG", "ML", "MM", "MN", "MO", "MS", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE", "NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PF", "PG", "PH", "PK", "PL", "PR", "PT", "PY", "QA", "RO", "RU", "SA", "SB", "SC", "SD", "SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ", "TD", "TG", "TH", "TJ", "TM", "TN", "TO", "TR", "TT", "TW", "TZ", "UA", "UG", "US", "UY", "UZ", "VC", "VE", "VN", "YE", "YU", "ZA", "ZM", "ZR", "ZW"} CList := []string{"AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BL", "BM", "BN", "BO", "BR", "BS", "BW", "BY", "BZ", "CA", "CF", "CG", "CH", "CK", "CL", "CM", "CN", "CO", "CR", "CS", "CU", "CY", "CZ", "DE", "DJ", "DK", "DO", "DZ", "EC", "EE", "EG", "ES", "ET", "FI", "FJ", "FR", "GA", "GB", "GD", "GE", "GF", "GH", "GI", "GM", "GN", "GR", "GT", "GU", "GY", "HK", "HN", "HT", "HU", "ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG", "KH", "KP", "KR", "KT", "KW", "KZ", "LA", "LB", "LC", "LI", "LK", "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "MG", "ML", "MM", "MN", "MO", "MS", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE", "NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PF", "PG", "PH", "PK", "PL", "PR", "PT", "PY", "QA", "RO", "RU", "SA", "SB", "SC", "SD", "SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ", "TD", "TG", "TH", "TJ", "TM", "TN", "TO", "TR", "TT", "TW", "TZ", "UA", "UG", "US", "UY", "UZ", "VC", "VE", "VN", "YE", "YU", "ZA", "ZM", "ZR", "ZW"}
domainSubfixList := []string{".com", ".edu", ".gov", ".int", ".mil", ".net", ".org", ".biz", ".info", ".pro", ".name", ".museum", ".coop", ".aero", ".xxx", ".idv", ".ac", ".ad", ".ae", ".af", ".ag", ".ai", ".al", ".am", ".an", ".ao", ".aq", ".ar", ".as", ".at", ".au", ".aw", ".az", ".ba", ".bb", ".bd", ".be", ".bf", ".bg", ".bh", ".bi", ".bj", ".bm", ".bn", ".bo", ".br", ".bs", ".bt", ".bv", ".bw", ".by", ".bz", ".ca", ".cc", ".cd", ".cf", ".cg", ".ch", ".ci", ".ck", ".cl", ".cm", ".cn", ".co", ".cr", ".cu", ".cv", ".cx", ".cy", ".cz", ".de", ".dj", ".dk", ".dm", ".do", ".dz", ".ec", ".ee", ".eg", ".eh", ".er", ".es", ".et", ".eu", ".fi", ".fj", ".fk", ".fm", ".fo", ".fr", ".ga", ".gd", ".ge", ".gf", ".gg", ".gh", ".gi", ".gl", ".gm", ".gn", ".gp", ".gq", ".gr", ".gs", ".gt", ".gu", ".gw", ".gy", ".hk", ".hm", ".hn", ".hr", ".ht", ".hu", ".id", ".ie", ".il", ".im", ".in", ".io", ".iq", ".ir", ".is", ".it", ".je", ".jm", ".jo", ".jp", ".ke", ".kg", ".kh", ".ki", ".km", ".kn", ".kp", ".kr", ".kw", ".ky", ".kz", ".la", ".lb", ".lc", ".li", ".lk", ".lr", ".ls", ".lt", ".lu", ".lv", ".ly", ".ma", ".mc", ".md", ".mg", ".mh", ".mk", ".ml", ".mm", ".mn", ".mo", ".mp", ".mq", ".mr", ".ms", ".mt", ".mu", ".mv", ".mw", ".mx", ".my", ".mz", ".na", ".nc", ".ne", ".nf", ".ng", ".ni", ".nl", ".no", ".np", ".nr", ".nu", ".nz", ".om", ".pa", ".pe", ".pf", ".pg", ".ph", ".pk", ".pl", ".pm", ".pn", ".pr", ".ps", ".pt", ".pw", ".py", ".qa", ".re", ".ro", ".ru", ".rw", ".sa", ".sb", ".sc", ".sd", ".se", ".sg", ".sh", ".si", ".sj", ".sk", ".sl", ".sm", ".sn", ".so", ".sr", ".st", ".sv", ".sy", ".sz", ".tc", ".td", ".tf", ".tg", ".th", ".tj", ".tk", ".tl", ".tm", ".tn", ".to", ".tp", ".tr", ".tt", ".tv", ".tw", ".tz", ".ua", ".ug", ".uk", ".um", ".us", ".uy", ".uz", ".va", ".vc", ".ve", ".vg", ".vi", ".vn", ".vu", ".wf", ".ws", ".ye", ".yt", ".yu", ".yr", ".za", ".zm", ".zw"} domainSubfixList := []string{".com", ".edu", ".gov", ".int", ".mil", ".net", ".org", ".biz", ".info", ".pro", ".name", ".museum", ".coop", ".aero", ".xxx", ".idv", ".ac", ".ad", ".ae", ".af", ".ag", ".ai", ".al", ".am", ".an", ".ao", ".aq", ".ar", ".as", ".at", ".au", ".aw", ".az", ".ba", ".bb", ".bd", ".be", ".bf", ".bg", ".bh", ".bi", ".bj", ".bm", ".bn", ".bo", ".br", ".bs", ".bt", ".bv", ".bw", ".by", ".bz", ".ca", ".cc", ".cd", ".cf", ".cg", ".ch", ".ci", ".ck", ".cl", ".cm", ".cn", ".co", ".cr", ".cu", ".cv", ".cx", ".cy", ".cz", ".de", ".dj", ".dk", ".dm", ".do", ".dz", ".ec", ".ee", ".eg", ".eh", ".er", ".es", ".et", ".eu", ".fi", ".fj", ".fk", ".fm", ".fo", ".fr", ".ga", ".gd", ".ge", ".gf", ".gg", ".gh", ".gi", ".gl", ".gm", ".gn", ".gp", ".gq", ".gr", ".gs", ".gt", ".gu", ".gw", ".gy", ".hk", ".hm", ".hn", ".hr", ".ht", ".hu", ".id", ".ie", ".il", ".im", ".in", ".io", ".iq", ".ir", ".is", ".it", ".je", ".jm", ".jo", ".jp", ".ke", ".kg", ".kh", ".ki", ".km", ".kn", ".kp", ".kr", ".kw", ".ky", ".kz", ".la", ".lb", ".lc", ".li", ".lk", ".lr", ".ls", ".lt", ".lu", ".lv", ".ly", ".ma", ".mc", ".md", ".mg", ".mh", ".mk", ".ml", ".mm", ".mn", ".mo", ".mp", ".mq", ".mr", ".ms", ".mt", ".mu", ".mv", ".mw", ".mx", ".my", ".mz", ".na", ".nc", ".ne", ".nf", ".ng", ".ni", ".nl", ".no", ".np", ".nr", ".nu", ".nz", ".om", ".pa", ".pe", ".pf", ".pg", ".ph", ".pk", ".pl", ".pm", ".pn", ".pr", ".ps", ".pt", ".pw", ".py", ".qa", ".re", ".ro", ".ru", ".rw", ".sa", ".sb", ".sc", ".sd", ".se", ".sg", ".sh", ".si", ".sj", ".sk", ".sl", ".sm", ".sn", ".so", ".sr", ".st", ".sv", ".sy", ".sz", ".tc", ".td", ".tf", ".tg", ".th", ".tj", ".tk", ".tl", ".tm", ".tn", ".to", ".tp", ".tr", ".tt", ".tv", ".tw", ".tz", ".ua", ".ug", ".uk", ".um", ".us", ".uy", ".uz", ".va", ".vc", ".ve", ".vg", ".vi", ".vn", ".vu", ".wf", ".ws", ".ye", ".yt", ".yu", ".yr", ".za", ".zm", ".zw"}
C := CList[int(RandInt(4))%len(CList)] C := CList[int(RandInt(4))%len(CList)]
ST := RandString(int(RandInt(4) % 10)) ST := RandString(int(RandInt(4) % 10))
O := RandString(int(RandInt(4) % 10)) O := RandString(int(RandInt(4) % 10))
CN := strings.ToLower(RandString(int(RandInt(4)%10)) + domainSubfixList[int(RandInt(4))%len(domainSubfixList)]) CN := strings.ToLower(RandString(int(RandInt(4)%10)) + domainSubfixList[int(RandInt(4))%len(domainSubfixList)])
cmdStr := fmt.Sprintf("openssl req -new -key proxy.key -x509 -days 36500 -out proxy.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, CN) log.Printf("C: %s, ST: %s, O: %s, CN: %s", C, ST, O, CN)
cmd = exec.Command("sh", "-c", cmdStr) var out []byte
out, err = cmd.CombinedOutput() if len(os.Args) == 3 && os.Args[2] == "ca" {
if err != nil { cmd := exec.Command("sh", "-c", "openssl genrsa -out ca.key 2048")
log.Printf("err:%s", err) out, err = cmd.CombinedOutput()
return if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
cmdStr := fmt.Sprintf("openssl req -new -key ca.key -x509 -days 36500 -out ca.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, "*."+CN)
cmd = exec.Command("sh", "-c", cmdStr)
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
} else if len(os.Args) == 5 && os.Args[2] == "ca" && os.Args[3] != "" && os.Args[4] != "" {
certBytes, _ := ioutil.ReadFile("ca.crt")
block, _ := pem.Decode(certBytes)
if block == nil || certBytes == nil {
panic("failed to parse ca certificate PEM")
}
x509Cert, _ := x509.ParseCertificate(block.Bytes)
if x509Cert == nil {
panic("failed to parse block")
}
name := os.Args[3]
days := os.Args[4]
cmd := exec.Command("sh", "-c", "openssl genrsa -out "+name+".key 2048")
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
cmdStr := fmt.Sprintf("openssl req -new -nodes -key %s.key -out %s.csr -days %s -subj /C=%s/ST=%s/O=%s/CN=%s", name, name, days, C, ST, O, CN)
cmd = exec.Command("sh", "-c", cmdStr)
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
cmdStr = fmt.Sprintf("openssl x509 -req -in %s.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out %s.crt", name, name)
cmd = exec.Command("sh", "-c", cmdStr)
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
} else if len(os.Args) == 2 {
cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
cmdStr := fmt.Sprintf("openssl req -new -key proxy.key -x509 -days 36500 -out proxy.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, CN)
cmd = exec.Command("sh", "-c", cmdStr)
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
} }
fmt.Println(string(out))
return return
} }
func GetAllInterfaceAddr() ([]net.IP, error) { func GetAllInterfaceAddr() ([]net.IP, error) {

View File

@ -42,8 +42,8 @@ func NewServerChannelHost(host string) ServerChannel {
func (sc *ServerChannel) SetErrAcceptHandler(fn func(err error)) { func (sc *ServerChannel) SetErrAcceptHandler(fn func(err error)) {
sc.errAcceptHandler = fn sc.errAcceptHandler = fn
} }
func (sc *ServerChannel) ListenTls(certBytes, keyBytes []byte, fn func(conn net.Conn)) (err error) { func (sc *ServerChannel) ListenTls(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) {
sc.Listener, err = ListenTls(sc.ip, sc.port, certBytes, keyBytes) sc.Listener, err = ListenTls(sc.ip, sc.port, certBytes, keyBytes, caCertBytes)
if err == nil { if err == nil {
go func() { go func() {
defer func() { defer func() {

View File

@ -491,25 +491,27 @@ func (req *HTTPRequest) addPortIfNot() (newHost string) {
} }
type OutPool struct { type OutPool struct {
Pool ConnPool Pool ConnPool
dur int dur int
typ string typ string
certBytes []byte certBytes []byte
keyBytes []byte keyBytes []byte
kcp kcpcfg.KCPConfigArgs caCertBytes []byte
address string kcp kcpcfg.KCPConfigArgs
timeout int address string
timeout int
} }
func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) { func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyBytes, caCertBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) {
op = OutPool{ op = OutPool{
dur: dur, dur: dur,
typ: typ, typ: typ,
certBytes: certBytes, certBytes: certBytes,
keyBytes: keyBytes, keyBytes: keyBytes,
kcp: kcp, caCertBytes: caCertBytes,
address: address, kcp: kcp,
timeout: timeout, address: address,
timeout: timeout,
} }
var err error var err error
op.Pool, err = NewConnPool(poolConfig{ op.Pool, err = NewConnPool(poolConfig{
@ -543,7 +545,7 @@ func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyByt
func (op *OutPool) getConn() (conn interface{}, err error) { func (op *OutPool) getConn() (conn interface{}, err error) {
if op.typ == "tls" { if op.typ == "tls" {
var _conn tls.Conn var _conn tls.Conn
_conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes) _conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes, op.caCertBytes)
if err == nil { if err == nil {
conn = net.Conn(&_conn) conn = net.Conn(&_conn)
} }