Signed-off-by: arraykeys@gmail.com <arraykeys@gmail.com>

This commit is contained in:
arraykeys@gmail.com
2017-10-18 18:01:53 +08:00
parent efb075c7ba
commit e28d5449b5
7 changed files with 603 additions and 106 deletions

View File

@ -1,4 +1,7 @@
proxy更新日志 proxy更新日志
v3.3
1.修复了socks代理模式对证书文件的判断逻辑.
v3.2 v3.2
1.内网穿透功能server端-r参数增加了协议和key设置. 1.内网穿透功能server端-r参数增加了协议和key设置.
2.手册增加了对-r参数的详细说明. 2.手册增加了对-r参数的详细说明.

View File

@ -1,15 +1,16 @@
package services package services
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"proxy/utils" "proxy/utils"
"proxy/utils/aes"
"proxy/utils/socks"
"runtime/debug"
"time" "time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -20,6 +21,8 @@ type Socks struct {
checker utils.Checker checker utils.Checker
basicAuth utils.BasicAuth basicAuth utils.BasicAuth
sshClient *ssh.Client sshClient *ssh.Client
lockChn chan bool
udpSC utils.ServerChannel
} }
func NewSocks() Service { func NewSocks() Service {
@ -27,16 +30,22 @@ func NewSocks() Service {
cfg: SocksArgs{}, cfg: SocksArgs{},
checker: utils.Checker{}, checker: utils.Checker{},
basicAuth: utils.BasicAuth{}, basicAuth: utils.BasicAuth{},
lockChn: make(chan bool, 1),
} }
} }
func (s *Socks) CheckArgs() { func (s *Socks) CheckArgs() {
var err error var err error
if *s.cfg.LocalType == "tls" {
log.Println(*s.cfg.CertFile, *s.cfg.KeyFile)
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
}
if *s.cfg.Parent != "" { if *s.cfg.Parent != "" {
if *s.cfg.ParentType == "" { if *s.cfg.ParentType == "" {
log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>") log.Fatalf("parent type unkown,use -T <tls|tcp|ssh>")
} }
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" { if *s.cfg.ParentType == "tls" {
log.Println(*s.cfg.CertFile, *s.cfg.KeyFile)
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
} }
if *s.cfg.ParentType == "ssh" { if *s.cfg.ParentType == "ssh" {
@ -77,11 +86,25 @@ func (s *Socks) InitService() {
log.Fatalf("init service fail, ERR: %s", err) log.Fatalf("init service fail, ERR: %s", err)
} }
} }
if *s.cfg.ParentType == "ssh" {
log.Println("warn: socks udp not suppored for ssh")
} else {
_, port, _ := net.SplitHostPort(*s.cfg.Local)
s.udpSC = utils.NewServerChannelHost(":" + port)
err := s.udpSC.ListenUDP(s.udpCallback)
if err != nil {
log.Fatalf("init udp service fail, ERR: %s", err)
}
log.Printf("udp socks proxy on %s", s.udpSC.UDPListener.LocalAddr())
}
} }
func (s *Socks) StopService() { func (s *Socks) StopService() {
if s.sshClient != nil { if s.sshClient != nil {
s.sshClient.Close() s.sshClient.Close()
} }
if s.udpSC.UDPListener != nil {
s.udpSC.UDPListener.Close()
}
} }
func (s *Socks) Start(args interface{}) (err error) { func (s *Socks) Start(args interface{}) (err error) {
//start() //start()
@ -93,9 +116,9 @@ func (s *Socks) Start(args interface{}) (err error) {
} }
sc := utils.NewServerChannelHost(*s.cfg.Local) sc := utils.NewServerChannelHost(*s.cfg.Local)
if *s.cfg.LocalType == TYPE_TCP { if *s.cfg.LocalType == TYPE_TCP {
err = sc.ListenTCP(s.callback) err = sc.ListenTCP(s.socksConnCallback)
} else { } else {
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.socksConnCallback)
} }
if err != nil { if err != nil {
return return
@ -106,95 +129,221 @@ func (s *Socks) Start(args interface{}) (err error) {
func (s *Socks) Clean() { func (s *Socks) Clean() {
s.StopService() s.StopService()
} }
func (s *Socks) callback(inConn net.Conn) { func (s *Socks) UDPKey() []byte {
return s.cfg.KeyBytes[:32]
}
func (s *Socks) udpCallback(b []byte, localAddr, srcAddr *net.UDPAddr) {
newB := b
var err error
if *s.cfg.LocalType == "tls" {
//decode b
newB, err = goaes.Decrypt(s.UDPKey(), b)
if err != nil {
log.Printf("decrypt udp packet fail from %s", srcAddr.String())
return
}
}
p, err := socks.ParseUDPPacket(newB)
log.Printf("udp revecived:%v", len(p.Data()))
if err != nil {
log.Printf("parse udp packet fail, ERR:%s", err)
return
}
//log.Printf("##########udp to -> %s:%s###########", p.Host(), p.Port())
if *s.cfg.Parent != "" {
//有上级代理,转发给上级
if *s.cfg.ParentType == "tls" {
//encode b
newB, err = goaes.Encrypt(s.UDPKey(), newB)
if err != nil {
log.Printf("encrypt udp data fail to %s", *s.cfg.Parent)
return
}
}
dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Parent)
if err != nil {
log.Printf("can't resolve address: %s", err)
return
}
clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
if err != nil {
log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
return
}
conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*2)))
_, err = conn.Write(newB)
log.Printf("udp request:%v", len(newB))
if err != nil {
log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
return
}
//log.Printf("send udp packet to %s success", dstAddr.String())
buf := make([]byte, 1024)
length, _, err := conn.ReadFromUDP(buf)
if err != nil {
log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
return
}
respBody := buf[0:length]
log.Printf("udp response:%v", len(respBody))
//log.Printf("revecived udp packet from %s", dstAddr.String())
if *s.cfg.ParentType == "tls" {
//decode b
respBody, err = goaes.Decrypt(s.UDPKey(), respBody)
if err != nil {
log.Printf("encrypt udp data fail to %s", *s.cfg.Parent)
return
}
}
if *s.cfg.LocalType == "tls" {
d, err := goaes.Encrypt(s.UDPKey(), respBody)
if err != nil {
log.Printf("encrypt udp data fail from %s", dstAddr.String())
return
}
s.udpSC.UDPListener.WriteToUDP(d, srcAddr)
log.Printf("udp reply:%v", len(d))
} else {
s.udpSC.UDPListener.WriteToUDP(respBody, srcAddr)
log.Printf("udp reply:%v", len(respBody))
}
} else {
//本地代理
dstAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.Host(), p.Port()))
if err != nil {
log.Printf("can't resolve address: %s", err)
return
}
clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
if err != nil {
log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
return
}
conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout*2)))
_, err = conn.Write(p.Data())
log.Printf("udp send:%v", len(p.Data()))
if err != nil {
log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
return
}
log.Printf("send udp packet to %s success", dstAddr.String())
buf := make([]byte, 1024)
length, _, err := conn.ReadFromUDP(buf)
if err != nil {
log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
return
}
respBody := buf[0:length]
//log.Printf("revecived udp packet from %s", dstAddr.String())
if *s.cfg.LocalType == "tls" {
d, err := goaes.Encrypt(s.UDPKey(), respBody)
if err != nil {
log.Printf("encrypt udp data fail from %s", dstAddr.String())
return
}
s.udpSC.UDPListener.WriteToUDP(d, srcAddr)
} else {
s.udpSC.UDPListener.WriteToUDP(respBody, srcAddr)
}
log.Printf("udp reply:%v", len(respBody))
}
}
func (s *Socks) socksConnCallback(inConn net.Conn) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
//log.Printf("socks conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) log.Printf("socks conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack()))
} }
utils.CloseConn(&inConn) utils.CloseConn(&inConn)
}() }()
var outConn net.Conn
defer utils.CloseConn(&outConn)
var b [1024]byte //method select request
n, err := inConn.Read(b[:]) methodReq, err := socks.NewMethodsRequest(inConn)
if err != nil { if err != nil || !methodReq.Select(socks.Method_NO_AUTH) {
if err != io.EOF { methodReq.Reply(socks.Method_NONE_ACCEPTABLE)
log.Printf("read request data fail,ERR: %s", err) utils.CloseConn(&inConn)
}
return return
} }
var reqBytes = b[:n] //method select reply
//log.Printf("% x", b[:n]) err = methodReq.Reply(socks.Method_NO_AUTH)
//reply
n, err = inConn.Write([]byte{0x05, 0x00})
if err != nil { if err != nil {
log.Printf("reply answer data fail,ERR: %s", err) log.Printf("reply answer data fail,ERR: %s", err)
utils.CloseConn(&inConn)
return return
} }
//read answer // log.Printf("% x", methodReq.Bytes())
n, err = inConn.Read(b[:])
//request detail
request, err := socks.NewRequest(inConn)
if err != nil { if err != nil {
log.Printf("read answer data fail,ERR: %s", err) log.Printf("read request data fail,ERR: %s", err)
utils.CloseConn(&inConn)
return return
} }
var headBytes = b[:n]
// log.Printf("% x", b[:n]) switch request.CMD() {
var addr string case socks.CMD_BIND:
switch b[3] { //bind 不支持
case 0x01: request.TCPReply(socks.REP_UNKNOWN)
sip := sockIP{} utils.CloseConn(&inConn)
if err := binary.Read(bytes.NewReader(b[4:n]), binary.BigEndian, &sip); err != nil { return
log.Printf("read ip fail,ERR: %s", err) case socks.CMD_CONNECT:
//tcp
s.proxyTCP(&inConn, methodReq, request)
case socks.CMD_ASSOCIATE:
//udp
s.proxyUDP(&inConn, methodReq, request)
}
}
func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) {
if *s.cfg.ParentType == "ssh" {
return return
} }
addr = sip.toAddr() host, _, _ := net.SplitHostPort((*inConn).LocalAddr().String())
case 0x03: _, port, _ := net.SplitHostPort(s.udpSC.UDPListener.LocalAddr().String())
host := string(b[5 : n-2]) // log.Printf("proxy udp on %s", net.JoinHostPort(host, port))
var port uint16 request.UDPReply(socks.REP_SUCCESS, net.JoinHostPort(host, port))
err = binary.Read(bytes.NewReader(b[n-2:n]), binary.BigEndian, &port) // log.Printf("%v", request.NewReply(socks.REP_SUCCESS, net.JoinHostPort(host, port)))
if err != nil {
log.Printf("read domain fail,ERR: %s", err)
return
}
addr = fmt.Sprintf("%s:%d", host, port)
} }
func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) {
var outConn net.Conn
defer utils.CloseConn(&outConn)
var err error
useProxy := true useProxy := true
if *s.cfg.Always { if *s.cfg.Always {
outConn, err = s.getOutConn(reqBytes, headBytes, addr) outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr())
} else { } else {
if *s.cfg.Parent != "" { if *s.cfg.Parent != "" {
s.checker.Add(addr, true, "", "", nil) s.checker.Add(request.Addr(), true, "", "", nil)
useProxy, _, _ = s.checker.IsBlocked(addr) useProxy, _, _ = s.checker.IsBlocked(request.Addr())
if useProxy { if useProxy {
outConn, err = s.getOutConn(reqBytes, headBytes, addr) outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr())
} else { } else {
outConn, err = utils.ConnectHost(addr, *s.cfg.Timeout) outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
} }
} else { } else {
outConn, err = utils.ConnectHost(addr, *s.cfg.Timeout) outConn, err = utils.ConnectHost(request.Addr(), *s.cfg.Timeout)
} }
} }
if err != nil { if err != nil {
log.Printf("get out conn fail,%s", err) log.Printf("get out conn fail,%s", err)
inConn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) request.TCPReply(socks.REP_NETWOR_UNREACHABLE)
return return
} }
log.Printf("use proxy %v : %s", useProxy, addr) log.Printf("use proxy %v : %s", useProxy, request.Addr())
inConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) request.TCPReply(socks.REP_SUCCESS)
inAddr := (*inConn).RemoteAddr().String()
inLocalAddr := (*inConn).LocalAddr().String()
inAddr := inConn.RemoteAddr().String() log.Printf("conn %s - %s connected [%s]", inAddr, inLocalAddr, request.Addr())
inLocalAddr := inConn.LocalAddr().String()
log.Printf("conn %s - %s connected [%s]", inAddr, inLocalAddr, addr)
// utils.IoBind(outConn, inConn, func(err error) {
// log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, addr)
// }, func(i int, b bool) {}, 0)
var bind = func() (err interface{}) { var bind = func() (err interface{}) {
defer func() { defer func() {
if err == nil { if err == nil {
@ -211,17 +360,17 @@ func (s *Socks) callback(inConn net.Conn) {
} }
} }
}() }()
_, err = io.Copy(outConn, inConn) _, err = io.Copy(outConn, (*inConn))
}() }()
_, err = io.Copy(inConn, outConn) _, err = io.Copy((*inConn), outConn)
return return
} }
bind() bind()
log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, addr) log.Printf("conn %s - %s released [%s]", inAddr, inLocalAddr, request.Addr())
utils.CloseConn(&inConn) utils.CloseConn(inConn)
utils.CloseConn(&outConn) utils.CloseConn(&outConn)
} }
func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net.Conn, err error) { func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn net.Conn, err error) {
switch *s.cfg.ParentType { switch *s.cfg.ParentType {
case "tls": case "tls":
fallthrough fallthrough
@ -238,7 +387,7 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net
} }
var buf = make([]byte, 1024) var buf = make([]byte, 1024)
//var n int //var n int
_, err = outConn.Write(reqBytes) _, err = outConn.Write(methodBytes)
if err != nil { if err != nil {
return return
} }
@ -249,7 +398,7 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net
//resp := buf[:n] //resp := buf[:n]
//log.Printf("resp:%v", resp) //log.Printf("resp:%v", resp)
outConn.Write(headBytes) outConn.Write(reqBytes)
_, err = outConn.Read(buf) _, err = outConn.Read(buf)
if err != nil { if err != nil {
return return
@ -267,7 +416,6 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net
outConn, err = s.sshClient.Dial("tcp", host) outConn, err = s.sshClient.Dial("tcp", host)
if err != nil { if err != nil {
log.Printf("connect ssh fail, ERR: %s, retrying...", err) log.Printf("connect ssh fail, ERR: %s, retrying...", err)
s.sshClient.Close()
e := s.ConnectSSH() e := s.ConnectSSH()
if e == nil { if e == nil {
tryCount++ tryCount++
@ -282,6 +430,12 @@ func (s *Socks) getOutConn(reqBytes, headBytes []byte, host string) (outConn net
return return
} }
func (s *Socks) ConnectSSH() (err error) { func (s *Socks) ConnectSSH() (err error) {
select {
case s.lockChn <- true:
default:
err = fmt.Errorf("can not connect at same time")
return
}
config := ssh.ClientConfig{ config := ssh.ClientConfig{
Timeout: time.Duration(*s.cfg.Timeout) * time.Millisecond, Timeout: time.Duration(*s.cfg.Timeout) * time.Millisecond,
User: *s.cfg.SSHUser, User: *s.cfg.SSHUser,
@ -290,15 +444,10 @@ func (s *Socks) ConnectSSH() (err error) {
return nil return nil
}, },
} }
if s.sshClient != nil {
s.sshClient.Close()
}
s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config) s.sshClient, err = ssh.Dial("tcp", *s.cfg.Parent, &config)
<-s.lockChn
return return
} }
type sockIP struct {
A, B, C, D byte
PORT uint16
}
func (ip sockIP) toAddr() string {
return fmt.Sprintf("%d.%d.%d.%d:%d", ip.A, ip.B, ip.C, ip.D, ip.PORT)
}

View File

@ -13,7 +13,6 @@ import (
type ServerConn struct { type ServerConn struct {
ClientLocalAddr string //tcp:2.2.22:333@ID ClientLocalAddr string //tcp:2.2.22:333@ID
Conn *net.Conn Conn *net.Conn
//Conn *utils.HeartbeatReadWriter
} }
type TunnelBridge struct { type TunnelBridge struct {
cfg TunnelBridgeArgs cfg TunnelBridgeArgs
@ -78,7 +77,6 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
return return
} }
key = string(_key) key = string(_key)
//log.Printf("conn key %s", key)
if connType != CONN_CONTROL { if connType != CONN_CONTROL {
var IDLength uint16 var IDLength uint16
@ -117,13 +115,8 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
switch connType { switch connType {
case CONN_SERVER: case CONN_SERVER:
// hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
// log.Printf("%s conn %s from server released", key, ID)
// s.serverConns.Remove(ID)
// })
addr := clientLocalAddr + "@" + ID addr := clientLocalAddr + "@" + ID
s.serverConns.Set(ID, ServerConn{ s.serverConns.Set(ID, ServerConn{
//Conn: &hb,
Conn: &inConn, Conn: &inConn,
ClientLocalAddr: addr, ClientLocalAddr: addr,
}) })
@ -134,7 +127,9 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
(*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err := (*item.(*net.Conn)).Write([]byte(addr)) _, err := (*item.(*net.Conn)).Write([]byte(addr))
(*item.(*net.Conn)).SetWriteDeadline(time.Time{})
if err != nil { if err != nil {
log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@ -151,33 +146,36 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
return return
} }
serverConn := serverConnItem.(ServerConn).Conn serverConn := serverConnItem.(ServerConn).Conn
// hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) {
// log.Printf("%s conn %s from client released", key, ID)
// hw.Close()
// })
utils.IoBind(*serverConn, inConn, func(err error) { utils.IoBind(*serverConn, inConn, func(err error) {
// utils.IoBind(serverConn, inConn, func(isSrcErr bool, err error) {
//serverConn.Close()
(*serverConn).Close() (*serverConn).Close()
utils.CloseConn(&inConn) utils.CloseConn(&inConn)
// hw.Close()
s.serverConns.Remove(ID) s.serverConns.Remove(ID)
log.Printf("conn %s released", ID) log.Printf("conn %s released", ID)
}, func(i int, b bool) {}, 0) }, func(i int, b bool) {}, 0)
log.Printf("conn %s created", ID) log.Printf("conn %s created", ID)
case CONN_CONTROL: case CONN_CONTROL:
if s.clientControlConns.Has(key) { if s.clientControlConns.Has(key) {
item, _ := s.clientControlConns.Get(key) item, _ := s.clientControlConns.Get(key)
//(*item.(*utils.HeartbeatReadWriter)).Close()
(*item.(*net.Conn)).Close() (*item.(*net.Conn)).Close()
} }
// hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
// log.Printf("client %s disconnected", key)
// s.clientControlConns.Remove(key)
// })
// s.clientControlConns.Set(key, &hb)
s.clientControlConns.Set(key, &inConn) s.clientControlConns.Set(key, &inConn)
log.Printf("set client %s control conn", key) log.Printf("set client %s control conn", key)
go func() {
for {
var b = make([]byte, 1)
_, err = inConn.Read(b)
if err != nil {
inConn.Close()
s.serverConns.Remove(ID)
log.Printf("%s control conn from client released", key)
break
} else {
//log.Printf("%s heartbeat from client", key)
}
}
}()
} }
}) })
if err != nil { if err != nil {

View File

@ -46,27 +46,33 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
for { for {
ctrlConn, err := s.GetInConn(CONN_CONTROL, "") ctrlConn, err := s.GetInConn(CONN_CONTROL, "")
if err != nil { if err != nil {
log.Printf("control connection err: %s", err) log.Printf("control connection err: %s, retrying...", err)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
utils.CloseConn(&ctrlConn) utils.CloseConn(&ctrlConn)
continue continue
} }
// rw := utils.NewHeartbeatReadWriter(&ctrlConn, 3, func(err error, hb *utils.HeartbeatReadWriter) { go func() {
// log.Printf("ctrlConn err %s", err) for {
// utils.CloseConn(&ctrlConn) ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
// }) _, err = ctrlConn.Write([]byte{0x00})
ctrlConn.SetWriteDeadline(time.Time{})
if err != nil {
utils.CloseConn(&ctrlConn)
log.Printf("ctrlConn err %s", err)
break
}
time.Sleep(time.Second * 3)
}
}()
for { for {
signal := make([]byte, 50) signal := make([]byte, 50)
// n, err := rw.Read(signal)
n, err := ctrlConn.Read(signal) n, err := ctrlConn.Read(signal)
if err != nil { if err != nil {
utils.CloseConn(&ctrlConn) utils.CloseConn(&ctrlConn)
log.Printf("read connection signal err: %s", err) log.Printf("read connection signal err: %s, retrying...", err)
break break
} }
addr := string(signal[:n]) addr := string(signal[:n])
// log.Printf("n:%d addr:%s err:%s", n, addr, err)
// os.Exit(0)
log.Printf("signal revecived:%s", addr) log.Printf("signal revecived:%s", addr)
protocol := addr[:3] protocol := addr[:3]
atIndex := strings.Index(addr, "@") atIndex := strings.Index(addr, "@")

84
utils/aes/aes.go Normal file
View File

@ -0,0 +1,84 @@
// Playbook - http://play.golang.org/p/3wFl4lacjX
package goaes
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"io"
"strings"
)
func addBase64Padding(value string) string {
m := len(value) % 4
if m != 0 {
value += strings.Repeat("=", 4-m)
}
return value
}
func removeBase64Padding(value string) string {
return strings.Replace(value, "=", "", -1)
}
func Pad(src []byte) []byte {
padding := aes.BlockSize - len(src)%aes.BlockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
func Unpad(src []byte) ([]byte, error) {
length := len(src)
unpadding := int(src[length-1])
if unpadding > length {
return nil, errors.New("unpad error. This could happen when incorrect encryption key is used")
}
return src[:(length - unpadding)], nil
}
func Encrypt(key []byte, text []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
msg := Pad(text)
ciphertext := make([]byte, aes.BlockSize+len(msg))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
cfb := cipher.NewCFBEncrypter(block, iv)
cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(msg))
return ciphertext, nil
}
func Decrypt(key []byte, text []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
if (len(text) % aes.BlockSize) != 0 {
return nil, errors.New("blocksize must be multipe of decoded message length")
}
iv := text[:aes.BlockSize]
msg := text[aes.BlockSize:]
cfb := cipher.NewCFBDecrypter(block, iv)
cfb.XORKeyStream(msg, msg)
unpadMsg, err := Unpad(msg)
if err != nil {
return nil, err
}
return unpadMsg, nil
}

View File

@ -194,6 +194,9 @@ func HTTPGet(URL string, timeout int) (err error) {
} }
func CloseConn(conn *net.Conn) { func CloseConn(conn *net.Conn) {
defer func() {
_ = recover()
}()
if conn != nil && *conn != nil { if conn != nil && *conn != nil {
(*conn).SetDeadline(time.Now().Add(time.Millisecond)) (*conn).SetDeadline(time.Now().Add(time.Millisecond))
(*conn).Close() (*conn).Close()

254
utils/socks/structs.go Normal file
View File

@ -0,0 +1,254 @@
package socks
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
)
const (
Method_NO_AUTH = uint8(0x00)
Method_GSSAPI = uint8(0x01)
Method_USER_PASS = uint8(0x02)
Method_IANA = uint8(0x7F)
Method_RESVERVE = uint8(0x80)
Method_NONE_ACCEPTABLE = uint8(0xFF)
VERSION_V5 = uint8(0x05)
CMD_CONNECT = uint8(0x01)
CMD_BIND = uint8(0x02)
CMD_ASSOCIATE = uint8(0x03)
ATYP_IPV4 = uint8(0x01)
ATYP_DOMAIN = uint8(0x03)
ATYP_IPV6 = uint8(0x04)
REP_SUCCESS = uint8(0x00)
REP_REQ_FAIL = uint8(0x01)
REP_RULE_FORBIDDEN = uint8(0x02)
REP_NETWOR_UNREACHABLE = uint8(0x03)
REP_HOST_UNREACHABLE = uint8(0x04)
REP_CONNECTION_REFUSED = uint8(0x05)
REP_TTL_TIMEOUT = uint8(0x06)
REP_CMD_UNSUPPORTED = uint8(0x07)
REP_ATYP_UNSUPPORTED = uint8(0x08)
REP_UNKNOWN = uint8(0x09)
RSV = uint8(0x00)
)
var (
ZERO_IP = []byte{0x00, 0x00, 0x00, 0x00}
ZERO_PORT = []byte{0x00, 0x00}
)
type Request struct {
ver uint8
cmd uint8
reserve uint8
addressType uint8
dstAddr string
dstPort string
dstHost string
bytes []byte
rw io.ReadWriter
}
func NewRequest(rw io.ReadWriter) (req Request, err interface{}) {
var b [1024]byte
var n int
req = Request{rw: rw}
n, err = rw.Read(b[:])
if err != nil {
err = fmt.Errorf("read req data fail,ERR: %s", err)
return
}
req.ver = uint8(b[0])
req.cmd = uint8(b[1])
req.reserve = uint8(b[2])
req.addressType = uint8(b[3])
if b[0] != 0x5 {
err = fmt.Errorf("sosck version supported")
req.TCPReply(REP_REQ_FAIL)
return
}
switch b[3] {
case 0x01: //IP V4
req.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String()
case 0x03: //域名
req.dstHost = string(b[5 : n-2]) //b[4]表示域名的长度
case 0x04: //IP V6
req.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String()
}
req.dstPort = strconv.Itoa(int(b[n-2])<<8 | int(b[n-1]))
req.dstAddr = net.JoinHostPort(req.dstHost, req.dstPort)
req.bytes = b[:n]
return
}
func (s *Request) Bytes() []byte {
return s.bytes
}
func (s *Request) Addr() string {
return s.dstAddr
}
func (s *Request) Host() string {
return s.dstHost
}
func (s *Request) Port() string {
return s.dstPort
}
func (s *Request) AType() uint8 {
return s.addressType
}
func (s *Request) CMD() uint8 {
return s.cmd
}
func (s *Request) TCPReply(rep uint8) (err error) {
_, err = s.rw.Write(s.NewReply(rep, "0.0.0.0:0"))
return
}
func (s *Request) UDPReply(rep uint8, addr string) (err error) {
_, err = s.rw.Write(s.NewReply(rep, addr))
return
}
func (s *Request) NewReply(rep uint8, addr string) []byte {
var response bytes.Buffer
host, port, _ := net.SplitHostPort(addr)
ip := net.ParseIP(host)
ipb := ip.To4()
atyp := ATYP_IPV4
ipv6 := ip.To16()
zeroiIPv6 := fmt.Sprintf("%d%d%d%d%d%d%d%d%d%d%d%d",
ipv6[0], ipv6[1], ipv6[2], ipv6[3],
ipv6[4], ipv6[5], ipv6[6], ipv6[7],
ipv6[8], ipv6[9], ipv6[10], ipv6[11],
)
if ipv6 != nil && "0000000000255255" != zeroiIPv6 {
atyp = ATYP_IPV6
ipb = ip.To16()
}
porti, _ := strconv.Atoi(port)
portb := make([]byte, 2)
binary.BigEndian.PutUint16(portb, uint16(porti))
// log.Printf("atyp : %v", atyp)
// log.Printf("ip : %v", []byte(ip))
response.WriteByte(VERSION_V5)
response.WriteByte(rep)
response.WriteByte(RSV)
response.WriteByte(atyp)
response.Write(ipb)
response.Write(portb)
return response.Bytes()
}
type MethodsRequest struct {
ver uint8
methodsCount uint8
methods []uint8
bytes []byte
rw *io.ReadWriter
}
func NewMethodsRequest(r io.ReadWriter) (s MethodsRequest, err interface{}) {
defer func() {
if err == nil {
err = recover()
}
}()
s = MethodsRequest{}
s.rw = &r
var buf = make([]byte, 300)
var n int
n, err = r.Read(buf)
if err != nil {
return
}
if buf[0] != 0x05 {
err = fmt.Errorf("socks version not supported")
return
}
if n != int(buf[1])+int(2) {
err = fmt.Errorf("socks methods data length error")
return
}
s.ver = buf[0]
s.methodsCount = buf[1]
s.methods = buf[2:n]
s.bytes = buf[:n]
return
}
func (s *MethodsRequest) Version() uint8 {
return s.ver
}
func (s *MethodsRequest) MethodsCount() uint8 {
return s.methodsCount
}
func (s *MethodsRequest) Select(method uint8) bool {
for _, m := range s.methods {
if m == method {
return true
}
}
return false
}
func (s *MethodsRequest) Reply(method uint8) (err error) {
_, err = (*s.rw).Write([]byte{byte(VERSION_V5), byte(method)})
return
}
func (s *MethodsRequest) Bytes() []byte {
return s.bytes
}
type UDPPacket struct {
rsv uint16
frag uint8
atype uint8
dstHost string
dstPort string
data []byte
header []byte
bytes []byte
}
func ParseUDPPacket(b []byte) (p UDPPacket, err error) {
p = UDPPacket{}
p.frag = uint8(b[2])
p.bytes = b
if p.frag != 0 {
err = fmt.Errorf("FRAG only support for 0 , %v ,%v", p.frag, b[:4])
return
}
portIndex := 0
p.atype = b[3]
switch p.atype {
case ATYP_IPV4: //IP V4
p.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String()
portIndex = 8
case ATYP_DOMAIN: //域名
domainLen := uint8(b[4])
p.dstHost = string(b[5 : 5+domainLen]) //b[4]表示域名的长度
portIndex = int(5 + domainLen)
case ATYP_IPV6: //IP V6
p.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String()
portIndex = 20
}
p.dstPort = strconv.Itoa(int(b[portIndex])<<8 | int(b[portIndex+1]))
p.data = b[portIndex+2:]
p.header = b[:portIndex+2]
return
}
func (s *UDPPacket) Header() []byte {
return s.header
}
func (s *UDPPacket) Host() string {
return s.dstHost
}
func (s *UDPPacket) Port() string {
return s.dstPort
}
func (s *UDPPacket) Data() []byte {
return s.data
}