From e28d5449b515733189a36416a49cd2edc57f9030 Mon Sep 17 00:00:00 2001 From: "arraykeys@gmail.com" Date: Wed, 18 Oct 2017 18:01:53 +0800 Subject: [PATCH] Signed-off-by: arraykeys@gmail.com --- CHANGELOG | 3 + services/socks.go | 303 ++++++++++++++++++++++++++++---------- services/tunnel_bridge.go | 38 +++-- services/tunnel_client.go | 24 +-- utils/aes/aes.go | 84 +++++++++++ utils/functions.go | 3 + utils/socks/structs.go | 254 ++++++++++++++++++++++++++++++++ 7 files changed, 603 insertions(+), 106 deletions(-) create mode 100644 utils/aes/aes.go create mode 100644 utils/socks/structs.go diff --git a/CHANGELOG b/CHANGELOG index 5122d0b..1bc9767 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,4 +1,7 @@ proxy更新日志 +v3.3 +1.修复了socks代理模式对证书文件的判断逻辑. + v3.2 1.内网穿透功能server端-r参数增加了协议和key设置. 2.手册增加了对-r参数的详细说明. diff --git a/services/socks.go b/services/socks.go index ee636ad..5a2cc1b 100644 --- a/services/socks.go +++ b/services/socks.go @@ -1,15 +1,16 @@ package services import ( - "bytes" "crypto/tls" - "encoding/binary" "fmt" "io" "io/ioutil" "log" "net" "proxy/utils" + "proxy/utils/aes" + "proxy/utils/socks" + "runtime/debug" "time" "golang.org/x/crypto/ssh" @@ -20,6 +21,8 @@ type Socks struct { checker utils.Checker basicAuth utils.BasicAuth sshClient *ssh.Client + lockChn chan bool + udpSC utils.ServerChannel } func NewSocks() Service { @@ -27,16 +30,22 @@ func NewSocks() Service { cfg: SocksArgs{}, checker: utils.Checker{}, basicAuth: utils.BasicAuth{}, + lockChn: make(chan bool, 1), } } func (s *Socks) CheckArgs() { var err error + if *s.cfg.LocalType == "tls" { + log.Println(*s.cfg.CertFile, *s.cfg.KeyFile) + s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) + } if *s.cfg.Parent != "" { if *s.cfg.ParentType == "" { log.Fatalf("parent type unkown,use -T ") } - if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" { + if *s.cfg.ParentType == "tls" { + log.Println(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) } if *s.cfg.ParentType == "ssh" { @@ -77,11 +86,25 @@ func (s *Socks) InitService() { log.Fatalf("init service fail, ERR: %s", err) } } + if *s.cfg.ParentType == "ssh" { + log.Println("warn: socks udp not suppored for ssh") + } else { + _, port, _ := net.SplitHostPort(*s.cfg.Local) + s.udpSC = utils.NewServerChannelHost(":" + port) + err := s.udpSC.ListenUDP(s.udpCallback) + if err != nil { + log.Fatalf("init udp service fail, ERR: %s", err) + } + log.Printf("udp socks proxy on %s", s.udpSC.UDPListener.LocalAddr()) + } } func (s *Socks) StopService() { if s.sshClient != nil { s.sshClient.Close() } + if s.udpSC.UDPListener != nil { + s.udpSC.UDPListener.Close() + } } func (s *Socks) Start(args interface{}) (err error) { //start() @@ -93,9 +116,9 @@ func (s *Socks) Start(args interface{}) (err error) { } sc := utils.NewServerChannelHost(*s.cfg.Local) if *s.cfg.LocalType == TYPE_TCP { - err = sc.ListenTCP(s.callback) + err = sc.ListenTCP(s.socksConnCallback) } else { - err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.socksConnCallback) } if err != nil { return @@ -106,95 +129,221 @@ func (s *Socks) Start(args interface{}) (err error) { func (s *Socks) Clean() { s.StopService() } -func (s *Socks) callback(inConn net.Conn) { +func (s *Socks) UDPKey() []byte { + return s.cfg.KeyBytes[:32] +} +func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) { + newB := b + var err error + if *s.cfg.LocalType == "tls" { + //decode b + newB, err = goaes.Decrypt(s.UDPKey(), b) + if err != nil { + log.Printf("decrypt udp packet fail from %s", srcAddr.String()) + return + } + } + p, err := socks.ParseUDPPacket(newB) + log.Printf("udp revecived:%v", len(p.Data())) + if err != nil { + log.Printf("parse udp packet fail, ERR:%s", err) + return + } + //log.Printf("##########udp to -> %s:%s###########", p.Host(), p.Port()) + if *s.cfg.Parent != "" { + //有上级代理,转发给上级 + if *s.cfg.ParentType == "tls" { + //encode b + newB, err = goaes.Encrypt(s.UDPKey(), newB) + if err != nil { + log.Printf("encrypt udp data fail to %s", *s.cfg.Parent) + return + } + } + dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Parent) + if err != nil { + log.Printf("can't resolve address: %s", err) + return + } + clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) + if err != nil { + log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) + return + } + conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*2))) + _, err = conn.Write(newB) + log.Printf("udp request:%v", len(newB)) + if err != nil { + log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) + return + } + + //log.Printf("send udp packet to %s success", dstAddr.String()) + buf := make([]byte, 1024) + length, _, err := conn.ReadFromUDP(buf) + if err != nil { + log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) + return + } + respBody := buf[0:length] + log.Printf("udp response:%v", len(respBody)) + //log.Printf("revecived udp packet from %s", dstAddr.String()) + if *s.cfg.ParentType == "tls" { + //decode b + respBody, err = goaes.Decrypt(s.UDPKey(), respBody) + if err != nil { + log.Printf("encrypt udp data fail to %s", *s.cfg.Parent) + return + } + } + if *s.cfg.LocalType == "tls" { + d, err := goaes.Encrypt(s.UDPKey(), respBody) + if err != nil { + log.Printf("encrypt udp data fail from %s", dstAddr.String()) + return + } + s.udpSC.UDPListener.WriteToUDP(d, srcAddr) + log.Printf("udp reply:%v", len(d)) + } else { + s.udpSC.UDPListener.WriteToUDP(respBody, srcAddr) + log.Printf("udp reply:%v", len(respBody)) + } + + } else { + //本地代理 + dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port())) + if err != nil { + log.Printf("can't resolve address: %s", err) + return + } + clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr) + if err != nil { + log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err) + return + } + conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*2))) + _, err = conn.Write(p.Data()) + log.Printf("udp send:%v", len(p.Data())) + if err != nil { + log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err) + return + } + log.Printf("send udp packet to %s success", dstAddr.String()) + buf := make([]byte, 1024) + length, _, err := conn.ReadFromUDP(buf) + if err != nil { + log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err) + return + } + respBody := buf[0:length] + //log.Printf("revecived udp packet from %s", dstAddr.String()) + if *s.cfg.LocalType == "tls" { + d, err := goaes.Encrypt(s.UDPKey(), respBody) + if err != nil { + log.Printf("encrypt udp data fail from %s", dstAddr.String()) + return + } + s.udpSC.UDPListener.WriteToUDP(d, srcAddr) + } else { + s.udpSC.UDPListener.WriteToUDP(respBody, srcAddr) + } + log.Printf("udp reply:%v", len(respBody)) + } + +} +func (s *Socks) socksConnCallback(inConn net.Conn) { defer func() { if err := recover(); err != nil { - //log.Printf("socks conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) + log.Printf("socks conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) } utils.CloseConn(&inConn) }() - var outConn net.Conn - defer utils.CloseConn(&outConn) - var b [1024]byte - n, err := inConn.Read(b[:]) - if err != nil { - if err != io.EOF { - log.Printf("read request data fail,ERR: %s", err) - } + //method select request + methodReq, err := socks.NewMethodsRequest(inConn) + if err != nil || !methodReq.Select(socks.Method_NO_AUTH) { + methodReq.Reply(socks.Method_NONE_ACCEPTABLE) + utils.CloseConn(&inConn) return } - var reqBytes = b[:n] - //log.Printf("% x", b[:n]) - - //reply - n, err = inConn.Write([]byte{0x05, 0x00}) + //method select reply + err = methodReq.Reply(socks.Method_NO_AUTH) if err != nil { log.Printf("reply answer data fail,ERR: %s", err) + utils.CloseConn(&inConn) return } - //read answer - n, err = inConn.Read(b[:]) + // log.Printf("% x", methodReq.Bytes()) + + //request detail + request, err := socks.NewRequest(inConn) if err != nil { - log.Printf("read answer data fail,ERR: %s", err) + log.Printf("read request data fail,ERR: %s", err) + utils.CloseConn(&inConn) return } - var headBytes = b[:n] - // log.Printf("% x", b[:n]) - var addr string - switch b[3] { - case 0x01: - sip := sockIP{} - if err := binary.Read(bytes.NewReader(b[4:n]), binary.BigEndian, &sip); err != nil { - log.Printf("read ip fail,ERR: %s", err) - return - } - addr = sip.toAddr() - case 0x03: - host := string(b[5 : n-2]) - var port uint16 - err = binary.Read(bytes.NewReader(b[n-2:n]), binary.BigEndian, &port) - if err != nil { - log.Printf("read domain fail,ERR: %s", err) - return - } - addr = fmt.Sprintf("%s:%d", host, port) + + switch request.CMD() { + case socks.CMD_BIND: + //bind 不支持 + request.TCPReply(socks.REP_UNKNOWN) + utils.CloseConn(&inConn) + return + case socks.CMD_CONNECT: + //tcp + s.proxyTCP(&inConn, methodReq, request) + case socks.CMD_ASSOCIATE: + //udp + s.proxyUDP(&inConn, methodReq, request) } + +} +func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) { + if *s.cfg.ParentType == "ssh" { + return + } + host, _, _ := net.SplitHostPort((*inConn).LocalAddr().String()) + _, port, _ := net.SplitHostPort(s.udpSC.UDPListener.LocalAddr().String()) + // log.Printf("proxy udp on %s", net.JoinHostPort(host, port)) + request.UDPReply(socks.REP_SUCCESS, net.JoinHostPort(host, port)) + // log.Printf("%v", request.NewReply(socks.REP_SUCCESS, net.JoinHostPort(host, port))) +} +func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) { + var outConn net.Conn + defer utils.CloseConn(&outConn) + var err error useProxy := true if *s.cfg.Always { - outConn, err = s.getOutConn(reqBytes, headBytes, addr) + outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr()) } else { if *s.cfg.Parent != "" { - s.checker.Add(addr, true, "", "", nil) - useProxy, _, _ = s.checker.IsBlocked(addr) + s.checker.Add(request.Addr(), true, "", "", nil) + useProxy, _, _ = s.checker.IsBlocked(request.Addr()) if useProxy { - outConn, err = s.getOutConn(reqBytes, headBytes, addr) + outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr()) } else { - outConn, err = utils.ConnectHost(addr, *s.cfg.Timeout) + outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) } } else { - outConn, err = utils.ConnectHost(addr, *s.cfg.Timeout) + outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout) } } if err != nil { log.Printf("get out conn fail,%s", err) - inConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + request.TCPReply(socks.REP_NETWOR_UNREACHABLE) return } - log.Printf("use proxy %v : %s", useProxy, addr) + log.Printf("use proxy %v : %s", useProxy, request.Addr()) - inConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + request.TCPReply(socks.REP_SUCCESS) + inAddr := (*inConn).RemoteAddr().String() + inLocalAddr := (*inConn).LocalAddr().String() - inAddr := inConn.RemoteAddr().String() - inLocalAddr := inConn.LocalAddr().String() - - log.Printf("conn %s - %s connected [%s]", inAddr, inLocalAddr, addr) - // utils.IoBind(outConn, inConn, func(err error) { - // log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, addr) - - // }, func(i int, b bool) {}, 0) + log.Printf("conn %s - %s connected [%s]", inAddr, inLocalAddr, request.Addr()) var bind = func() (err interface{}) { defer func() { if err == nil { @@ -211,17 +360,17 @@ func (s *Socks) callback(inConn net.Conn) { } } }() - _, err = io.Copy(outConn, inConn) + _, err = io.Copy(outConn, (*inConn)) }() - _, err = io.Copy(inConn, outConn) + _, err = io.Copy((*inConn), outConn) return } bind() - log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, addr) - utils.CloseConn(&inConn) + log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, request.Addr()) + utils.CloseConn(inConn) utils.CloseConn(&outConn) } -func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net.Conn, err error) { +func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn net.Conn, err error) { switch *s.cfg.ParentType { case "tls": fallthrough @@ -238,7 +387,7 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net } var buf = make([]byte, 1024) //var n int - _, err = outConn.Write(reqBytes) + _, err = outConn.Write(methodBytes) if err != nil { return } @@ -249,7 +398,7 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net //resp := buf[:n] //log.Printf("resp:%v", resp) - outConn.Write(headBytes) + outConn.Write(reqBytes) _, err = outConn.Read(buf) if err != nil { return @@ -267,7 +416,6 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net outConn, err = s.sshClient.Dial("tcp", host) if err != nil { log.Printf("connect ssh fail, ERR: %s, retrying...", err) - s.sshClient.Close() e := s.ConnectSSH() if e == nil { tryCount++ @@ -282,6 +430,12 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net return } func (s *Socks) ConnectSSH() (err error) { + select { + case s.lockChn <- true: + default: + err = fmt.Errorf("can not connect at same time") + return + } config := ssh.ClientConfig{ Timeout: time.Duration(*s.cfg.Timeout) * time.Millisecond, User: *s.cfg.SSHUser, @@ -290,15 +444,10 @@ func (s *Socks) ConnectSSH() (err error) { return nil }, } + if s.sshClient != nil { + s.sshClient.Close() + } s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config) + <-s.lockChn return } - -type sockIP struct { - A, B, C, D byte - PORT uint16 -} - -func (ip sockIP) toAddr() string { - return fmt.Sprintf("%d.%d.%d.%d:%d", ip.A, ip.B, ip.C, ip.D, ip.PORT) -} diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index 683ca91..ab9bd57 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -13,7 +13,6 @@ import ( type ServerConn struct { ClientLocalAddr string //tcp:2.2.22:333@ID Conn *net.Conn - //Conn *utils.HeartbeatReadWriter } type TunnelBridge struct { cfg TunnelBridgeArgs @@ -78,7 +77,6 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { return } key = string(_key) - //log.Printf("conn key %s", key) if connType != CONN_CONTROL { var IDLength uint16 @@ -117,13 +115,8 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { switch connType { case CONN_SERVER: - // hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { - // log.Printf("%s conn %s from server released", key, ID) - // s.serverConns.Remove(ID) - // }) addr := clientLocalAddr + "@" + ID s.serverConns.Set(ID, ServerConn{ - //Conn: &hb, Conn: &inConn, ClientLocalAddr: addr, }) @@ -134,7 +127,9 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { time.Sleep(time.Second * 3) continue } + (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) _, err := (*item.(*net.Conn)).Write([]byte(addr)) + (*item.(*net.Conn)).SetWriteDeadline(time.Time{}) if err != nil { log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) time.Sleep(time.Second * 3) @@ -151,33 +146,36 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { return } serverConn := serverConnItem.(ServerConn).Conn - // hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) { - // log.Printf("%s conn %s from client released", key, ID) - // hw.Close() - // }) utils.IoBind(*serverConn, inConn, func(err error) { - // utils.IoBind(serverConn, inConn, func(isSrcErr bool, err error) { - //serverConn.Close() + (*serverConn).Close() utils.CloseConn(&inConn) - // hw.Close() s.serverConns.Remove(ID) log.Printf("conn %s released", ID) }, func(i int, b bool) {}, 0) log.Printf("conn %s created", ID) + case CONN_CONTROL: if s.clientControlConns.Has(key) { item, _ := s.clientControlConns.Get(key) - //(*item.(*utils.HeartbeatReadWriter)).Close() (*item.(*net.Conn)).Close() } - // hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { - // log.Printf("client %s disconnected", key) - // s.clientControlConns.Remove(key) - // }) - // s.clientControlConns.Set(key, &hb) s.clientControlConns.Set(key, &inConn) log.Printf("set client %s control conn", key) + go func() { + for { + var b = make([]byte, 1) + _, err = inConn.Read(b) + if err != nil { + inConn.Close() + s.serverConns.Remove(ID) + log.Printf("%s control conn from client released", key) + break + } else { + //log.Printf("%s heartbeat from client", key) + } + } + }() } }) if err != nil { diff --git a/services/tunnel_client.go b/services/tunnel_client.go index ff2e0d6..a77ad04 100644 --- a/services/tunnel_client.go +++ b/services/tunnel_client.go @@ -46,27 +46,33 @@ func (s *TunnelClient) Start(args interface{}) (err error) { for { ctrlConn, err := s.GetInConn(CONN_CONTROL, "") if err != nil { - log.Printf("control connection err: %s", err) + log.Printf("control connection err: %s, retrying...", err) time.Sleep(time.Second * 3) utils.CloseConn(&ctrlConn) continue } - // rw := utils.NewHeartbeatReadWriter(&ctrlConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { - // log.Printf("ctrlConn err %s", err) - // utils.CloseConn(&ctrlConn) - // }) + go func() { + for { + ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) + _, err = ctrlConn.Write([]byte{0x00}) + ctrlConn.SetWriteDeadline(time.Time{}) + if err != nil { + utils.CloseConn(&ctrlConn) + log.Printf("ctrlConn err %s", err) + break + } + time.Sleep(time.Second * 3) + } + }() for { signal := make([]byte, 50) - // n, err := rw.Read(signal) n, err := ctrlConn.Read(signal) if err != nil { utils.CloseConn(&ctrlConn) - log.Printf("read connection signal err: %s", err) + log.Printf("read connection signal err: %s, retrying...", err) break } addr := string(signal[:n]) - // log.Printf("n:%d addr:%s err:%s", n, addr, err) - // os.Exit(0) log.Printf("signal revecived:%s", addr) protocol := addr[:3] atIndex := strings.Index(addr, "@") diff --git a/utils/aes/aes.go b/utils/aes/aes.go new file mode 100644 index 0000000..3d4536e --- /dev/null +++ b/utils/aes/aes.go @@ -0,0 +1,84 @@ +// Playbook - http://play.golang.org/p/3wFl4lacjX + +package goaes + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "io" + "strings" +) + +func addBase64Padding(value string) string { + m := len(value) % 4 + if m != 0 { + value += strings.Repeat("=", 4-m) + } + + return value +} + +func removeBase64Padding(value string) string { + return strings.Replace(value, "=", "", -1) +} + +func Pad(src []byte) []byte { + padding := aes.BlockSize - len(src)%aes.BlockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(src, padtext...) +} + +func Unpad(src []byte) ([]byte, error) { + length := len(src) + unpadding := int(src[length-1]) + + if unpadding > length { + return nil, errors.New("unpad error. This could happen when incorrect encryption key is used") + } + + return src[:(length - unpadding)], nil +} + +func Encrypt(key []byte, text []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + msg := Pad(text) + ciphertext := make([]byte, aes.BlockSize+len(msg)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + cfb := cipher.NewCFBEncrypter(block, iv) + cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(msg)) + + return ciphertext, nil +} + +func Decrypt(key []byte, text []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if (len(text) % aes.BlockSize) != 0 { + return nil, errors.New("blocksize must be multipe of decoded message length") + } + iv := text[:aes.BlockSize] + msg := text[aes.BlockSize:] + + cfb := cipher.NewCFBDecrypter(block, iv) + cfb.XORKeyStream(msg, msg) + + unpadMsg, err := Unpad(msg) + if err != nil { + return nil, err + } + + return unpadMsg, nil +} diff --git a/utils/functions.go b/utils/functions.go index 753c738..9ccc7a8 100755 --- a/utils/functions.go +++ b/utils/functions.go @@ -194,6 +194,9 @@ func HTTPGet(URL string, timeout int) (err error) { } func CloseConn(conn *net.Conn) { + defer func() { + _ = recover() + }() if conn != nil && *conn != nil { (*conn).SetDeadline(time.Now().Add(time.Millisecond)) (*conn).Close() diff --git a/utils/socks/structs.go b/utils/socks/structs.go new file mode 100644 index 0000000..fab2ed4 --- /dev/null +++ b/utils/socks/structs.go @@ -0,0 +1,254 @@ +package socks + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "strconv" +) + +const ( + Method_NO_AUTH = uint8(0x00) + Method_GSSAPI = uint8(0x01) + Method_USER_PASS = uint8(0x02) + Method_IANA = uint8(0x7F) + Method_RESVERVE = uint8(0x80) + Method_NONE_ACCEPTABLE = uint8(0xFF) + VERSION_V5 = uint8(0x05) + CMD_CONNECT = uint8(0x01) + CMD_BIND = uint8(0x02) + CMD_ASSOCIATE = uint8(0x03) + ATYP_IPV4 = uint8(0x01) + ATYP_DOMAIN = uint8(0x03) + ATYP_IPV6 = uint8(0x04) + REP_SUCCESS = uint8(0x00) + REP_REQ_FAIL = uint8(0x01) + REP_RULE_FORBIDDEN = uint8(0x02) + REP_NETWOR_UNREACHABLE = uint8(0x03) + REP_HOST_UNREACHABLE = uint8(0x04) + REP_CONNECTION_REFUSED = uint8(0x05) + REP_TTL_TIMEOUT = uint8(0x06) + REP_CMD_UNSUPPORTED = uint8(0x07) + REP_ATYP_UNSUPPORTED = uint8(0x08) + REP_UNKNOWN = uint8(0x09) + RSV = uint8(0x00) +) + +var ( + ZERO_IP = []byte{0x00, 0x00, 0x00, 0x00} + ZERO_PORT = []byte{0x00, 0x00} +) + +type Request struct { + ver uint8 + cmd uint8 + reserve uint8 + addressType uint8 + dstAddr string + dstPort string + dstHost string + bytes []byte + rw io.ReadWriter +} + +func NewRequest(rw io.ReadWriter) (req Request, err interface{}) { + var b [1024]byte + var n int + req = Request{rw: rw} + n, err = rw.Read(b[:]) + if err != nil { + err = fmt.Errorf("read req data fail,ERR: %s", err) + return + } + req.ver = uint8(b[0]) + req.cmd = uint8(b[1]) + req.reserve = uint8(b[2]) + req.addressType = uint8(b[3]) + + if b[0] != 0x5 { + err = fmt.Errorf("sosck version supported") + req.TCPReply(REP_REQ_FAIL) + return + } + switch b[3] { + case 0x01: //IP V4 + req.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String() + case 0x03: //域名 + req.dstHost = string(b[5 : n-2]) //b[4]表示域名的长度 + case 0x04: //IP V6 + req.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String() + } + req.dstPort = strconv.Itoa(int(b[n-2])<<8 | int(b[n-1])) + req.dstAddr = net.JoinHostPort(req.dstHost, req.dstPort) + req.bytes = b[:n] + return +} +func (s *Request) Bytes() []byte { + return s.bytes +} +func (s *Request) Addr() string { + return s.dstAddr +} +func (s *Request) Host() string { + return s.dstHost +} +func (s *Request) Port() string { + return s.dstPort +} +func (s *Request) AType() uint8 { + return s.addressType +} +func (s *Request) CMD() uint8 { + return s.cmd +} + +func (s *Request) TCPReply(rep uint8) (err error) { + _, err = s.rw.Write(s.NewReply(rep, "0.0.0.0:0")) + return +} +func (s *Request) UDPReply(rep uint8, addr string) (err error) { + _, err = s.rw.Write(s.NewReply(rep, addr)) + return +} +func (s *Request) NewReply(rep uint8, addr string) []byte { + var response bytes.Buffer + host, port, _ := net.SplitHostPort(addr) + ip := net.ParseIP(host) + ipb := ip.To4() + atyp := ATYP_IPV4 + ipv6 := ip.To16() + zeroiIPv6 := fmt.Sprintf("%d%d%d%d%d%d%d%d%d%d%d%d", + ipv6[0], ipv6[1], ipv6[2], ipv6[3], + ipv6[4], ipv6[5], ipv6[6], ipv6[7], + ipv6[8], ipv6[9], ipv6[10], ipv6[11], + ) + if ipv6 != nil && "0000000000255255" != zeroiIPv6 { + atyp = ATYP_IPV6 + ipb = ip.To16() + } + porti, _ := strconv.Atoi(port) + portb := make([]byte, 2) + binary.BigEndian.PutUint16(portb, uint16(porti)) + // log.Printf("atyp : %v", atyp) + // log.Printf("ip : %v", []byte(ip)) + response.WriteByte(VERSION_V5) + response.WriteByte(rep) + response.WriteByte(RSV) + response.WriteByte(atyp) + response.Write(ipb) + response.Write(portb) + return response.Bytes() +} + +type MethodsRequest struct { + ver uint8 + methodsCount uint8 + methods []uint8 + bytes []byte + rw *io.ReadWriter +} + +func NewMethodsRequest(r io.ReadWriter) (s MethodsRequest, err interface{}) { + defer func() { + if err == nil { + err = recover() + } + }() + s = MethodsRequest{} + s.rw = &r + var buf = make([]byte, 300) + var n int + n, err = r.Read(buf) + if err != nil { + return + } + if buf[0] != 0x05 { + err = fmt.Errorf("socks version not supported") + return + } + if n != int(buf[1])+int(2) { + err = fmt.Errorf("socks methods data length error") + return + } + + s.ver = buf[0] + s.methodsCount = buf[1] + s.methods = buf[2:n] + s.bytes = buf[:n] + return +} +func (s *MethodsRequest) Version() uint8 { + return s.ver +} +func (s *MethodsRequest) MethodsCount() uint8 { + return s.methodsCount +} +func (s *MethodsRequest) Select(method uint8) bool { + for _, m := range s.methods { + if m == method { + return true + } + } + return false +} +func (s *MethodsRequest) Reply(method uint8) (err error) { + _, err = (*s.rw).Write([]byte{byte(VERSION_V5), byte(method)}) + return +} +func (s *MethodsRequest) Bytes() []byte { + return s.bytes +} + +type UDPPacket struct { + rsv uint16 + frag uint8 + atype uint8 + dstHost string + dstPort string + data []byte + header []byte + bytes []byte +} + +func ParseUDPPacket(b []byte) (p UDPPacket, err error) { + p = UDPPacket{} + p.frag = uint8(b[2]) + p.bytes = b + if p.frag != 0 { + err = fmt.Errorf("FRAG only support for 0 , %v ,%v", p.frag, b[:4]) + return + } + portIndex := 0 + p.atype = b[3] + switch p.atype { + case ATYP_IPV4: //IP V4 + p.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String() + portIndex = 8 + case ATYP_DOMAIN: //域名 + domainLen := uint8(b[4]) + p.dstHost = string(b[5 : 5+domainLen]) //b[4]表示域名的长度 + portIndex = int(5 + domainLen) + case ATYP_IPV6: //IP V6 + p.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String() + portIndex = 20 + } + p.dstPort = strconv.Itoa(int(b[portIndex])<<8 | int(b[portIndex+1])) + p.data = b[portIndex+2:] + p.header = b[:portIndex+2] + return +} +func (s *UDPPacket) Header() []byte { + return s.header +} +func (s *UDPPacket) Host() string { + return s.dstHost +} + +func (s *UDPPacket) Port() string { + return s.dstPort +} +func (s *UDPPacket) Data() []byte { + return s.data +}