From 34e9e362b925833a45c312c3b768a74d3dbaee93 Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Mon, 12 Mar 2018 17:31:35 +0800 Subject: [PATCH] Signed-off-by: arraykeys@gmail.com --- config.go | 3 + services/args.go | 6 ++ services/http.go | 10 +- services/mux_bridge.go | 2 +- services/mux_client.go | 2 +- services/mux_server.go | 2 +- services/socks.go | 24 +++-- services/sps.go | 12 ++- services/tcp.go | 4 +- services/tunnel_bridge.go | 2 +- services/tunnel_client.go | 2 +- services/tunnel_server.go | 4 +- services/udp.go | 2 +- utils/functions.go | 187 ++++++++++++++++++++++++-------------- utils/serve-channel.go | 4 +- utils/structs.go | 36 ++++---- 16 files changed, 194 insertions(+), 108 deletions(-) diff --git a/config.go b/config.go index f980254..5d5e671 100755 --- a/config.go +++ b/config.go @@ -73,6 +73,7 @@ func initConfig() (err error) { //########http######### http := app.Command("http", "proxy on http mode") httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() + httpArgs.CaCertFile = http.Flag("ca", "ca cert file for tls").Default("").String() httpArgs.CertFile = http.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() httpArgs.KeyFile = http.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() httpArgs.LocalType = http.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp") @@ -189,6 +190,7 @@ func initConfig() (err error) { socksArgs.UDPParent = socks.Flag("udp-parent", "udp parent address, such as: \"23.32.32.19:33090\"").Default("").Short('X').String() socksArgs.UDPLocal = socks.Flag("udp-local", "udp local ip:port to listen").Short('x').Default(":33090").String() socksArgs.CertFile = socks.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() + socksArgs.CaCertFile = socks.Flag("ca", "ca cert file for tls").Default("").String() socksArgs.KeyFile = socks.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() socksArgs.SSHUser = socks.Flag("ssh-user", "user for ssh").Short('u').Default("").String() socksArgs.SSHKeyFile = socks.Flag("ssh-key", "private key file for ssh").Short('S').Default("").String() @@ -213,6 +215,7 @@ func initConfig() (err error) { spsArgs.Parent = sps.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() spsArgs.CertFile = sps.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() spsArgs.KeyFile = sps.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() + spsArgs.CaCertFile = sps.Flag("ca", "ca cert file for tls").Default("").String() spsArgs.Timeout = sps.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Short('i').Default("2000").Int() spsArgs.ParentType = sps.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "kcp") spsArgs.LocalType = sps.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp") diff --git a/services/args.go b/services/args.go index 39496e1..2786f79 100644 --- a/services/args.go +++ b/services/args.go @@ -120,6 +120,8 @@ type HTTPArgs struct { Parent *string CertFile *string KeyFile *string + CaCertFile *string + CaCertBytes []byte CertBytes []byte KeyBytes []byte Local *string @@ -169,6 +171,8 @@ type SocksArgs struct { LocalType *string CertFile *string KeyFile *string + CaCertFile *string + CaCertBytes []byte CertBytes []byte KeyBytes []byte SSHKeyFile *string @@ -199,6 +203,8 @@ type SPSArgs struct { Parent *string CertFile *string KeyFile *string + CaCertFile *string + CaCertBytes []byte CertBytes []byte KeyBytes []byte Local *string diff --git a/services/http.go b/services/http.go index 8621a8a..9587303 100644 --- a/services/http.go +++ b/services/http.go @@ -41,6 +41,12 @@ func (s *HTTP) CheckArgs() { } if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" { s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) + if *s.cfg.CaCertFile != "" { + s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile) + if err != nil { + log.Fatalf("read ca file error,ERR:%s", err) + } + } } if *s.cfg.ParentType == "ssh" { if *s.cfg.SSHUser == "" { @@ -128,7 +134,7 @@ func (s *HTTP) Start(args interface{}) (err error) { if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.callback) } else if *s.cfg.LocalType == TYPE_TLS { - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes, s.callback) } else if *s.cfg.LocalType == TYPE_KCP { err = sc.ListenKCP(s.cfg.KCP, s.callback) } @@ -321,7 +327,7 @@ func (s *HTTP) InitOutConnPool() { *s.cfg.CheckParentInterval, *s.cfg.ParentType, s.cfg.KCP, - s.cfg.CertBytes, s.cfg.KeyBytes, + s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes, s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, *s.cfg.PoolSize, diff --git a/services/mux_bridge.go b/services/mux_bridge.go index adf361d..37e52c7 100644 --- a/services/mux_bridge.go +++ b/services/mux_bridge.go @@ -56,7 +56,7 @@ func (s *MuxBridge) Start(args interface{}) (err error) { if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.handler) } else if *s.cfg.LocalType == TYPE_TLS { - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.handler) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.handler) } else if *s.cfg.LocalType == TYPE_KCP { err = sc.ListenKCP(s.cfg.KCP, s.handler) } diff --git a/services/mux_client.go b/services/mux_client.go index 79af08c..56b6bbf 100644 --- a/services/mux_client.go +++ b/services/mux_client.go @@ -126,7 +126,7 @@ func (s *MuxClient) Clean() { func (s *MuxClient) getParentConn() (conn net.Conn, err error) { if *s.cfg.ParentType == "tls" { var _conn tls.Conn - _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) if err == nil { conn = net.Conn(&_conn) } diff --git a/services/mux_server.go b/services/mux_server.go index 643443b..69e47b1 100644 --- a/services/mux_server.go +++ b/services/mux_server.go @@ -290,7 +290,7 @@ func (s *MuxServer) GetConn(index string) (conn net.Conn, err error) { func (s *MuxServer) getParentConn() (conn net.Conn, err error) { if *s.cfg.ParentType == "tls" { var _conn tls.Conn - _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) if err == nil { conn = net.Conn(&_conn) } diff --git a/services/socks.go b/services/socks.go index b69ed91..bea8180 100644 --- a/services/socks.go +++ b/services/socks.go @@ -37,16 +37,28 @@ func NewSocks() Service { func (s *Socks) CheckArgs() { var err error - if *s.cfg.LocalType == "tls" { + if *s.cfg.LocalType == "tls" || (*s.cfg.Parent != "" && *s.cfg.ParentType == "tls") { s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) + if *s.cfg.CaCertFile != "" { + s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile) + if err != nil { + log.Fatalf("read ca file error,ERR:%s", err) + } + } } if *s.cfg.Parent != "" { if *s.cfg.ParentType == "" { log.Fatalf("parent type unkown,use -T ") } - if *s.cfg.ParentType == "tls" { - s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) - } + // if *s.cfg.ParentType == "tls" { + // s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) + // if *s.cfg.CaCertFile != "" { + // s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile) + // if err != nil { + // log.Fatalf("read ca file error,ERR:%s", err) + // } + // } + // } if *s.cfg.ParentType == "ssh" { if *s.cfg.SSHUser == "" { log.Fatalf("ssh user required") @@ -138,7 +150,7 @@ func (s *Socks) Start(args interface{}) (err error) { if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.socksConnCallback) } else if *s.cfg.LocalType == TYPE_TLS { - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.socksConnCallback) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.socksConnCallback) } else if *s.cfg.LocalType == TYPE_KCP { err = sc.ListenKCP(s.cfg.KCP, s.socksConnCallback) } @@ -471,7 +483,7 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n case "tcp": if *s.cfg.ParentType == "tls" { var _outConn tls.Conn - _outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + _outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) outConn = net.Conn(&_outConn) } else if *s.cfg.ParentType == "kcp" { outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), s.cfg.KCP) diff --git a/services/sps.go b/services/sps.go index 1e4e4a5..1b5de03 100644 --- a/services/sps.go +++ b/services/sps.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io/ioutil" "log" "net" "runtime/debug" @@ -34,6 +35,13 @@ func (s *SPS) CheckArgs() { } if *s.cfg.ParentType == TYPE_TLS || *s.cfg.LocalType == TYPE_TLS { s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) + if *s.cfg.CaCertFile != "" { + var err error + s.cfg.CaCertBytes, err = ioutil.ReadFile(*s.cfg.CaCertFile) + if err != nil { + log.Fatalf("read ca file error,ERR:%s", err) + } + } } } func (s *SPS) InitService() { @@ -47,7 +55,7 @@ func (s *SPS) InitOutConnPool() { 0, *s.cfg.ParentType, s.cfg.KCP, - s.cfg.CertBytes, s.cfg.KeyBytes, + s.cfg.CertBytes, s.cfg.KeyBytes, nil, *s.cfg.Parent, *s.cfg.Timeout, 0, @@ -75,7 +83,7 @@ func (s *SPS) Start(args interface{}) (err error) { if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.callback) } else if *s.cfg.LocalType == TYPE_TLS { - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback) } else if *s.cfg.LocalType == TYPE_KCP { err = sc.ListenKCP(s.cfg.KCP, s.callback) } diff --git a/services/tcp.go b/services/tcp.go index 85136a2..34df8ee 100644 --- a/services/tcp.go +++ b/services/tcp.go @@ -56,7 +56,7 @@ func (s *TCP) Start(args interface{}) (err error) { if *s.cfg.LocalType == TYPE_TCP { err = sc.ListenTCP(s.callback) } else if *s.cfg.LocalType == TYPE_TLS { - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback) } else if *s.cfg.LocalType == TYPE_KCP { err = sc.ListenKCP(s.cfg.KCP, s.callback) } @@ -172,7 +172,7 @@ func (s *TCP) InitOutConnPool() { *s.cfg.CheckParentInterval, *s.cfg.ParentType, s.cfg.KCP, - s.cfg.CertBytes, s.cfg.KeyBytes, + s.cfg.CertBytes, s.cfg.KeyBytes, nil, *s.cfg.Parent, *s.cfg.Timeout, *s.cfg.PoolSize, diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index bb048dd..ddae0a7 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -51,7 +51,7 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { p, _ := strconv.Atoi(port) sc := utils.NewServerChannel(host, p) - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) { + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, func(inConn net.Conn) { //log.Printf("connection from %s ", inConn.RemoteAddr()) reader := bufio.NewReader(inConn) diff --git a/services/tunnel_client.go b/services/tunnel_client.go index 94daf95..0d93483 100644 --- a/services/tunnel_client.go +++ b/services/tunnel_client.go @@ -161,7 +161,7 @@ func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, e } func (s *TunnelClient) GetConn() (conn net.Conn, err error) { var _conn tls.Conn - _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) if err == nil { conn = net.Conn(&_conn) } diff --git a/services/tunnel_server.go b/services/tunnel_server.go index 9ef3763..b4fbfa0 100644 --- a/services/tunnel_server.go +++ b/services/tunnel_server.go @@ -174,7 +174,7 @@ func (s *TunnelServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string } func (s *TunnelServerManager) GetConn() (conn net.Conn, err error) { var _conn tls.Conn - _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) if err == nil { conn = net.Conn(&_conn) } @@ -280,7 +280,7 @@ func (s *TunnelServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err e } func (s *TunnelServer) GetConn() (conn net.Conn, err error) { var _conn tls.Conn - _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes) + _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil) if err == nil { conn = net.Conn(&_conn) } diff --git a/services/udp.go b/services/udp.go index 6f63f0d..74951fb 100644 --- a/services/udp.go +++ b/services/udp.go @@ -210,7 +210,7 @@ func (s *UDP) InitOutConnPool() { *s.cfg.CheckParentInterval, *s.cfg.ParentType, kcpcfg.KCPConfigArgs{}, - s.cfg.CertBytes, s.cfg.KeyBytes, + s.cfg.CertBytes, s.cfg.KeyBytes, nil, *s.cfg.Parent, *s.cfg.Timeout, *s.cfg.PoolSize, diff --git a/utils/functions.go b/utils/functions.go index 3927909..9d75e53 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -86,14 +86,14 @@ func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) { } } } -func TlsConnectHost(host string, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) { +func TlsConnectHost(host string, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) { h := strings.Split(host, ":") port, _ := strconv.Atoi(h[1]) - return TlsConnect(h[0], port, timeout, certBytes, keyBytes) + return TlsConnect(h[0], port, timeout, certBytes, keyBytes, caCertBytes) } -func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) { - conf, err := getRequestTlsConfig(certBytes, keyBytes) +func TlsConnect(host string, port, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) { + conf, err := getRequestTlsConfig(certBytes, keyBytes, caCertBytes) if err != nil { return } @@ -103,8 +103,24 @@ func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (con } return *tls.Client(_conn, conf), err } -func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err error) { - block, _ := pem.Decode(certBytes) +func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) { + + var cert tls.Certificate + cert, err = tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + return + } + serverCertPool := x509.NewCertPool() + caBytes := certBytes + if caCertBytes != nil { + caBytes = caCertBytes + + } + ok := serverCertPool.AppendCertsFromPEM(caBytes) + if !ok { + err = errors.New("failed to parse root certificate") + } + block, _ := pem.Decode(caBytes) if block == nil { panic("failed to parse certificate PEM") } @@ -112,34 +128,24 @@ func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err erro if x509Cert == nil { panic("failed to parse block") } - var cert tls.Certificate - cert, err = tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - return - } - serverCertPool := x509.NewCertPool() - ok := serverCertPool.AppendCertsFromPEM(certBytes) - if !ok { - err = errors.New("failed to parse root certificate") - } conf = &tls.Config{ RootCAs: serverCertPool, Certificates: []tls.Certificate{cert}, - InsecureSkipVerify: false, + InsecureSkipVerify: true, ServerName: x509Cert.Subject.CommonName, - // VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - // opts := x509.VerifyOptions{ - // Roots: serverCertPool, - // } - // for _, rawCert := range rawCerts { - // cert, _ := x509.ParseCertificate(rawCert) - // _, err := cert.Verify(opts) - // if err != nil { - // return err - // } - // } - // return nil - // }, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + opts := x509.VerifyOptions{ + Roots: serverCertPool, + } + for _, rawCert := range rawCerts { + cert, _ := x509.ParseCertificate(rawCert) + _, err := cert.Verify(opts) + if err != nil { + return err + } + } + return nil + }, } return } @@ -165,22 +171,19 @@ func ConnectKCPHost(hostAndPort string, config kcpcfg.KCPConfigArgs) (conn net.C return NewCompStream(kcpconn), err } -func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listener, err error) { - block, _ := pem.Decode(certBytes) - if block == nil { - panic("failed to parse certificate PEM") - } - x509Cert, _ := x509.ParseCertificate(block.Bytes) - if x509Cert == nil { - panic("failed to parse block") - } +func ListenTls(ip string, port int, certBytes, keyBytes, caCertBytes []byte) (ln *net.Listener, err error) { + var cert tls.Certificate cert, err = tls.X509KeyPair(certBytes, keyBytes) if err != nil { return } clientCertPool := x509.NewCertPool() - ok := clientCertPool.AppendCertsFromPEM(certBytes) + caBytes := certBytes + if caCertBytes != nil { + caBytes = caCertBytes + } + ok := clientCertPool.AppendCertsFromPEM(caBytes) if !ok { err = errors.New("failed to parse root certificate") } @@ -188,21 +191,6 @@ func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listene ClientCAs: clientCertPool, Certificates: []tls.Certificate{cert}, ClientAuth: tls.RequireAndVerifyClientCert, - ServerName: x509Cert.Subject.CommonName, - // VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - // opts := x509.VerifyOptions{ - // Roots: clientCertPool, - // } - // for _, rawCert := range rawCerts { - // cert, _ := x509.ParseCertificate(rawCert) - // _, err := cert.Verify(opts) - // fmt.Println("SERVER ERR:", err) - // if err != nil { - // return err - // } - // } - // return nil - // }, } _ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config) if err == nil { @@ -245,27 +233,88 @@ func CloseConn(conn *net.Conn) { } } func Keygen() (err error) { - cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048") - out, err := cmd.CombinedOutput() - if err != nil { - log.Printf("err:%s", err) - return - } - fmt.Println(string(out)) CList := []string{"AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BL", "BM", "BN", "BO", "BR", "BS", "BW", "BY", "BZ", "CA", "CF", "CG", "CH", "CK", "CL", "CM", "CN", "CO", "CR", "CS", "CU", "CY", "CZ", "DE", "DJ", "DK", "DO", "DZ", "EC", "EE", "EG", "ES", "ET", "FI", "FJ", "FR", "GA", "GB", "GD", "GE", "GF", "GH", "GI", "GM", "GN", "GR", "GT", "GU", "GY", "HK", "HN", "HT", "HU", "ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG", "KH", "KP", "KR", "KT", "KW", "KZ", "LA", "LB", "LC", "LI", "LK", "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "MG", "ML", "MM", "MN", "MO", "MS", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE", "NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PF", "PG", "PH", "PK", "PL", "PR", "PT", "PY", "QA", "RO", "RU", "SA", "SB", "SC", "SD", "SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ", "TD", "TG", "TH", "TJ", "TM", "TN", "TO", "TR", "TT", "TW", "TZ", "UA", "UG", "US", "UY", "UZ", "VC", "VE", "VN", "YE", "YU", "ZA", "ZM", "ZR", "ZW"} domainSubfixList := []string{".com", ".edu", ".gov", ".int", ".mil", ".net", ".org", ".biz", ".info", ".pro", ".name", ".museum", ".coop", ".aero", ".xxx", ".idv", ".ac", ".ad", ".ae", ".af", ".ag", ".ai", ".al", ".am", ".an", ".ao", ".aq", ".ar", ".as", ".at", ".au", ".aw", ".az", ".ba", ".bb", ".bd", ".be", ".bf", ".bg", ".bh", ".bi", ".bj", ".bm", ".bn", ".bo", ".br", ".bs", ".bt", ".bv", ".bw", ".by", ".bz", ".ca", ".cc", ".cd", ".cf", ".cg", ".ch", ".ci", ".ck", ".cl", ".cm", ".cn", ".co", ".cr", ".cu", ".cv", ".cx", ".cy", ".cz", ".de", ".dj", ".dk", ".dm", ".do", ".dz", ".ec", ".ee", ".eg", ".eh", ".er", ".es", ".et", ".eu", ".fi", ".fj", ".fk", ".fm", ".fo", ".fr", ".ga", ".gd", ".ge", ".gf", ".gg", ".gh", ".gi", ".gl", ".gm", ".gn", ".gp", ".gq", ".gr", ".gs", ".gt", ".gu", ".gw", ".gy", ".hk", ".hm", ".hn", ".hr", ".ht", ".hu", ".id", ".ie", ".il", ".im", ".in", ".io", ".iq", ".ir", ".is", ".it", ".je", ".jm", ".jo", ".jp", ".ke", ".kg", ".kh", ".ki", ".km", ".kn", ".kp", ".kr", ".kw", ".ky", ".kz", ".la", ".lb", ".lc", ".li", ".lk", ".lr", ".ls", ".lt", ".lu", ".lv", ".ly", ".ma", ".mc", ".md", ".mg", ".mh", ".mk", ".ml", ".mm", ".mn", ".mo", ".mp", ".mq", ".mr", ".ms", ".mt", ".mu", ".mv", ".mw", ".mx", ".my", ".mz", ".na", ".nc", ".ne", ".nf", ".ng", ".ni", ".nl", ".no", ".np", ".nr", ".nu", ".nz", ".om", ".pa", ".pe", ".pf", ".pg", ".ph", ".pk", ".pl", ".pm", ".pn", ".pr", ".ps", ".pt", ".pw", ".py", ".qa", ".re", ".ro", ".ru", ".rw", ".sa", ".sb", ".sc", ".sd", ".se", ".sg", ".sh", ".si", ".sj", ".sk", ".sl", ".sm", ".sn", ".so", ".sr", ".st", ".sv", ".sy", ".sz", ".tc", ".td", ".tf", ".tg", ".th", ".tj", ".tk", ".tl", ".tm", ".tn", ".to", ".tp", ".tr", ".tt", ".tv", ".tw", ".tz", ".ua", ".ug", ".uk", ".um", ".us", ".uy", ".uz", ".va", ".vc", ".ve", ".vg", ".vi", ".vn", ".vu", ".wf", ".ws", ".ye", ".yt", ".yu", ".yr", ".za", ".zm", ".zw"} C := CList[int(RandInt(4))%len(CList)] ST := RandString(int(RandInt(4) % 10)) O := RandString(int(RandInt(4) % 10)) CN := strings.ToLower(RandString(int(RandInt(4)%10)) + domainSubfixList[int(RandInt(4))%len(domainSubfixList)]) - cmdStr := fmt.Sprintf("openssl req -new -key proxy.key -x509 -days 36500 -out proxy.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, CN) - cmd = exec.Command("sh", "-c", cmdStr) - out, err = cmd.CombinedOutput() - if err != nil { - log.Printf("err:%s", err) - return + log.Printf("C: %s, ST: %s, O: %s, CN: %s", C, ST, O, CN) + var out []byte + if len(os.Args) == 3 && os.Args[2] == "ca" { + cmd := exec.Command("sh", "-c", "openssl genrsa -out ca.key 2048") + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return + } + fmt.Println(string(out)) + + cmdStr := fmt.Sprintf("openssl req -new -key ca.key -x509 -days 36500 -out ca.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, "*."+CN) + cmd = exec.Command("sh", "-c", cmdStr) + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return + } + fmt.Println(string(out)) + } else if len(os.Args) == 5 && os.Args[2] == "ca" && os.Args[3] != "" && os.Args[4] != "" { + certBytes, _ := ioutil.ReadFile("ca.crt") + block, _ := pem.Decode(certBytes) + if block == nil || certBytes == nil { + panic("failed to parse ca certificate PEM") + } + x509Cert, _ := x509.ParseCertificate(block.Bytes) + if x509Cert == nil { + panic("failed to parse block") + } + name := os.Args[3] + days := os.Args[4] + cmd := exec.Command("sh", "-c", "openssl genrsa -out "+name+".key 2048") + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return + } + fmt.Println(string(out)) + + cmdStr := fmt.Sprintf("openssl req -new -nodes -key %s.key -out %s.csr -days %s -subj /C=%s/ST=%s/O=%s/CN=%s", name, name, days, C, ST, O, CN) + cmd = exec.Command("sh", "-c", cmdStr) + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return + } + fmt.Println(string(out)) + + cmdStr = fmt.Sprintf("openssl x509 -req -in %s.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out %s.crt", name, name) + cmd = exec.Command("sh", "-c", cmdStr) + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return + } + + fmt.Println(string(out)) + } else if len(os.Args) == 2 { + cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048") + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return + } + fmt.Println(string(out)) + + cmdStr := fmt.Sprintf("openssl req -new -key proxy.key -x509 -days 36500 -out proxy.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, CN) + cmd = exec.Command("sh", "-c", cmdStr) + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return + } + fmt.Println(string(out)) } - fmt.Println(string(out)) + return } func GetAllInterfaceAddr() ([]net.IP, error) { diff --git a/utils/serve-channel.go b/utils/serve-channel.go index a53f00e..596cede 100644 --- a/utils/serve-channel.go +++ b/utils/serve-channel.go @@ -42,8 +42,8 @@ func NewServerChannelHost(host string) ServerChannel { func (sc *ServerChannel) SetErrAcceptHandler(fn func(err error)) { sc.errAcceptHandler = fn } -func (sc *ServerChannel) ListenTls(certBytes, keyBytes []byte, fn func(conn net.Conn)) (err error) { - sc.Listener, err = ListenTls(sc.ip, sc.port, certBytes, keyBytes) +func (sc *ServerChannel) ListenTls(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) { + sc.Listener, err = ListenTls(sc.ip, sc.port, certBytes, keyBytes, caCertBytes) if err == nil { go func() { defer func() { diff --git a/utils/structs.go b/utils/structs.go index a88e75c..9528e4f 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -491,25 +491,27 @@ func (req *HTTPRequest) addPortIfNot() (newHost string) { } type OutPool struct { - Pool ConnPool - dur int - typ string - certBytes []byte - keyBytes []byte - kcp kcpcfg.KCPConfigArgs - address string - timeout int + Pool ConnPool + dur int + typ string + certBytes []byte + keyBytes []byte + caCertBytes []byte + kcp kcpcfg.KCPConfigArgs + address string + timeout int } -func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) { +func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyBytes, caCertBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) { op = OutPool{ - dur: dur, - typ: typ, - certBytes: certBytes, - keyBytes: keyBytes, - kcp: kcp, - address: address, - timeout: timeout, + dur: dur, + typ: typ, + certBytes: certBytes, + keyBytes: keyBytes, + caCertBytes: caCertBytes, + kcp: kcp, + address: address, + timeout: timeout, } var err error op.Pool, err = NewConnPool(poolConfig{ @@ -543,7 +545,7 @@ func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyByt func (op *OutPool) getConn() (conn interface{}, err error) { if op.typ == "tls" { var _conn tls.Conn - _conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes) + _conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes, op.caCertBytes) if err == nil { conn = net.Conn(&_conn) }