Signed-off-by: arraykeys@gmail.com <arraykeys@gmail.com>

This commit is contained in:
arraykeys@gmail.com
2017-09-30 19:00:46 +08:00
parent 635e107d66
commit 69fcc7d12e
13 changed files with 618 additions and 188 deletions

View File

@ -1,4 +1,8 @@
proxy更新日志:
v3.1
1.优化了内网穿透功能,client只需要启动一个即可,server端启动的时候可以指定client端要暴露的端口。
2.
v3.0
1.此次更新不兼容2.x版本,重构了全部代码,架构更合理,利于功能模块的增加与维护。
2.增加了代理死循环检查,增强了安全性。

View File

@ -37,7 +37,6 @@ func initConfig() (err error) {
app = kingpin.New("proxy", "happy with proxy")
app.Author("snail").Version(APP_VERSION)
args.Parent = app.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
args.Local = app.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
certTLS := app.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
keyTLS := app.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
@ -55,6 +54,7 @@ func initConfig() (err error) {
httpArgs.Auth = http.Flag("auth", "http basic auth username and password, mutiple user repeat -a ,such as: -a user1:pass1 -a user2:pass2").Short('a').Strings()
httpArgs.PoolSize = http.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Short('L').Default("20").Int()
httpArgs.CheckParentInterval = http.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int()
httpArgs.Local = http.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
//########tcp#########
tcp := app.Command("tcp", "proxy on tcp mode")
@ -63,6 +63,7 @@ func initConfig() (err error) {
tcpArgs.IsTLS = tcp.Flag("tls", "proxy on tls mode").Default("false").Bool()
tcpArgs.PoolSize = tcp.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Short('L').Default("20").Int()
tcpArgs.CheckParentInterval = tcp.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int()
tcpArgs.Local = tcp.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
//########udp#########
udp := app.Command("udp", "proxy on udp mode")
@ -70,22 +71,25 @@ func initConfig() (err error) {
udpArgs.ParentType = udp.Flag("parent-type", "parent protocol type <tls|tcp|udp>").Short('T').Enum("tls", "tcp", "udp")
udpArgs.PoolSize = udp.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Short('L').Default("20").Int()
udpArgs.CheckParentInterval = udp.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int()
udpArgs.Local = udp.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
//########tunnel-server#########
tunnelServer := app.Command("tserver", "proxy on tunnel server mode")
tunnelServerArgs.Timeout = tunnelServer.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int()
tunnelServerArgs.IsUDP = tunnelServer.Flag("udp", "proxy on udp tunnel server mode").Default("false").Bool()
tunnelServerArgs.Key = tunnelServer.Flag("k", "key same with client").Default("default").String()
tunnelServerArgs.Key = tunnelServer.Flag("k", "client key").Default("default").String()
tunnelServerArgs.Remote = tunnelServer.Flag("remote", "client's network host:port").Short('r').Default("").String()
tunnelServerArgs.Local = tunnelServer.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
//########tunnel-client#########
tunnelClient := app.Command("tclient", "proxy on tunnel client mode")
tunnelClientArgs.Timeout = tunnelClient.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int()
tunnelClientArgs.IsUDP = tunnelClient.Flag("udp", "proxy on udp tunnel client mode").Default("false").Bool()
tunnelClientArgs.Key = tunnelClient.Flag("k", "key same with server").Default("default").String()
//########tunnel-bridge#########
tunnelBridge := app.Command("tbridge", "proxy on tunnel bridge mode")
tunnelBridgeArgs.Timeout = tunnelBridge.Flag("timeout", "tcp timeout with milliseconds").Short('t').Default("2000").Int()
tunnelBridgeArgs.Local = tunnelBridge.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
kingpin.MustParse(app.Parse(os.Args[1:]))

View File

@ -1,10 +1,6 @@
#!/bin/bash
set -e
if [ -e /tmp/proxy ]; then
rm -rf /tmp/proxy
fi
mkdir /tmp/proxy
cd /tmp/proxy
# install monexec
tar zxvf monexec_0.1.1_linux_amd64.tar.gz
cd monexec_0.1.1_linux_amd64

View File

@ -1,5 +1,5 @@
#!/bin/bash
VER="3.0"
VER="3.1"
RELEASE="release-${VER}"
rm -rf .cert
mkdir .cert

View File

@ -14,29 +14,31 @@ const (
)
type Args struct {
Local *string
Parent *string
CertBytes []byte
KeyBytes []byte
}
type TunnelServerArgs struct {
Args
Local *string
IsUDP *bool
Key *string
Remote *string
Timeout *int
}
type TunnelClientArgs struct {
Args
IsUDP *bool
Key *string
Timeout *int
}
type TunnelBridgeArgs struct {
Args
Local *string
Timeout *int
}
type TCPArgs struct {
Args
Local *string
ParentType *string
IsTLS *bool
Timeout *int
@ -46,6 +48,7 @@ type TCPArgs struct {
type HTTPArgs struct {
Args
Local *string
Always *bool
HTTPTimeout *int
Interval *int
@ -61,6 +64,7 @@ type HTTPArgs struct {
}
type UDPArgs struct {
Args
Local *string
ParentType *string
Timeout *int
PoolSize *int

View File

@ -1,6 +1,7 @@
package services
import (
"bufio"
"fmt"
"io"
"log"
@ -108,7 +109,7 @@ func (s *TCP) OutToTCP(inConn *net.Conn) (err error) {
func (s *TCP) OutToUDP(inConn *net.Conn) (err error) {
log.Printf("conn created , remote : %s ", (*inConn).RemoteAddr())
for {
srcAddr, body, err := utils.ReadUDPPacket(inConn)
srcAddr, body, err := utils.ReadUDPPacket(bufio.NewReader(*inConn))
if err == io.EOF || err == io.ErrUnexpectedEOF {
//log.Printf("connection %s released", srcAddr)
utils.CloseConn(inConn)

View File

@ -3,31 +3,28 @@ package services
import (
"bufio"
"encoding/binary"
"fmt"
"log"
"net"
"proxy/utils"
"strconv"
"sync"
"time"
)
type BridgeItem struct {
ServerChn chan *net.Conn
ClientChn chan *net.Conn
ClientControl *net.Conn
Once *sync.Once
Key string
type ServerConn struct {
ClientLocalAddr string //tcp:2.2.22:333@ID
Conn *net.Conn
}
type TunnelBridge struct {
cfg TunnelBridgeArgs
br utils.ConcurrentMap
cfg TunnelBridgeArgs
serverConns utils.ConcurrentMap
clientControlConns utils.ConcurrentMap
}
func NewTunnelBridge() Service {
return &TunnelBridge{
cfg: TunnelBridgeArgs{},
br: utils.NewConcurrentMap(),
cfg: TunnelBridgeArgs{},
serverConns: utils.NewConcurrentMap(),
clientControlConns: utils.NewConcurrentMap(),
}
}
@ -52,6 +49,8 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
sc := utils.NewServerChannel(host, p)
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) {
//log.Printf("connection from %s ", inConn.RemoteAddr())
reader := bufio.NewReader(inConn)
var connType uint8
err = binary.Read(reader, binary.LittleEndian, &connType)
@ -59,35 +58,116 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
utils.CloseConn(&inConn)
return
}
var key string
//log.Printf("conn type %d", connType)
var key, clientLocalAddr, ID string
var connTypeStrMap = map[uint8]string{CONN_SERVER: "server", CONN_CLIENT: "client", CONN_CONTROL: "client"}
if connType == CONN_SERVER || connType == CONN_CLIENT || connType == CONN_CONTROL {
var keyLength uint16
err = binary.Read(reader, binary.LittleEndian, &keyLength)
if err != nil {
return
}
_key := make([]byte, keyLength)
n, err := reader.Read(_key)
if err != nil {
return
}
if n != int(keyLength) {
return
}
key = string(_key)
log.Printf("connection from %s , key: %s", connTypeStrMap[connType], key)
var keyLength uint16
err = binary.Read(reader, binary.LittleEndian, &keyLength)
if err != nil {
return
}
_key := make([]byte, keyLength)
n, err := reader.Read(_key)
if err != nil {
return
}
if n != int(keyLength) {
return
}
key = string(_key)
//log.Printf("conn key %s", key)
if connType != CONN_CONTROL {
var IDLength uint16
err = binary.Read(reader, binary.LittleEndian, &IDLength)
if err != nil {
return
}
_id := make([]byte, IDLength)
n, err := reader.Read(_id)
if err != nil {
return
}
if n != int(IDLength) {
return
}
ID = string(_id)
if connType == CONN_SERVER {
var addrLength uint16
err = binary.Read(reader, binary.LittleEndian, &addrLength)
if err != nil {
return
}
_addr := make([]byte, addrLength)
n, err = reader.Read(_addr)
if err != nil {
return
}
if n != int(addrLength) {
return
}
clientLocalAddr = string(_addr)
}
}
log.Printf("connection from %s , key: %s , id: %s", connTypeStrMap[connType], key, ID)
switch connType {
case CONN_SERVER:
s.ServerConn(&inConn, key)
addr := clientLocalAddr + "@" + ID
s.serverConns.Set(ID, ServerConn{
Conn: &inConn,
ClientLocalAddr: addr,
})
for {
item, ok := s.clientControlConns.Get(key)
if !ok {
log.Printf("client %s control conn not exists", key)
time.Sleep(time.Second * 3)
continue
}
_, err := (*item.(*utils.HeartbeatReadWriter)).Write([]byte(addr))
if err != nil {
log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
time.Sleep(time.Second * 3)
continue
} else {
break
}
}
case CONN_CLIENT:
s.ClientConn(&inConn, key)
serverConnItem, ok := s.serverConns.Get(ID)
if !ok {
inConn.Close()
log.Printf("server conn %s exists", ID)
return
}
serverConn := serverConnItem.(ServerConn).Conn
// hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) {
// log.Printf("hw err %s", err)
// hw.Close()
// })
// utils.IoBind(*serverConn, &hw, func(isSrcErr bool, err error) {
utils.IoBind(*serverConn, inConn, func(isSrcErr bool, err error) {
utils.CloseConn(serverConn)
utils.CloseConn(&inConn)
s.serverConns.Remove(ID)
log.Printf("conn %s released", ID)
}, func(i int, b bool) {}, 0)
log.Printf("conn %s created", ID)
case CONN_CONTROL:
s.ClientControlConn(&inConn, key)
default:
log.Printf("unkown conn type %d", connType)
utils.CloseConn(&inConn)
if s.clientControlConns.Has(key) {
item, _ := s.clientControlConns.Get(key)
(*item.(*utils.HeartbeatReadWriter)).Close()
}
hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
log.Printf("client %s disconnected", key)
s.clientControlConns.Remove(key)
})
s.clientControlConns.Set(key, &hb)
log.Printf("set client %s control conn", key)
}
})
if err != nil {
@ -99,85 +179,3 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
func (s *TunnelBridge) Clean() {
s.StopService()
}
func (s *TunnelBridge) ClientConn(inConn *net.Conn, key string) {
chn, _ := s.ConnChn(key, CONN_CLIENT)
chn <- inConn
}
func (s *TunnelBridge) ServerConn(inConn *net.Conn, key string) {
chn, _ := s.ConnChn(key, CONN_SERVER)
chn <- inConn
}
func (s *TunnelBridge) ClientControlConn(inConn *net.Conn, key string) {
_, item := s.ConnChn(key, CONN_CLIENT)
utils.CloseConn(item.ClientControl)
if item.ClientControl != nil {
*item.ClientControl = *inConn
} else {
item.ClientControl = inConn
}
log.Printf("set client control conn,remote: %s", (*inConn).RemoteAddr())
}
func (s *TunnelBridge) ConnChn(key string, typ uint8) (chn chan *net.Conn, item *BridgeItem) {
s.br.SetIfAbsent(key, &BridgeItem{
ServerChn: make(chan *net.Conn, 10000),
ClientChn: make(chan *net.Conn, 10000),
Once: &sync.Once{},
Key: key,
})
_item, _ := s.br.Get(key)
item = _item.(*BridgeItem)
item.Once.Do(func() {
s.ChnDeamon(item)
})
if typ == CONN_CLIENT {
chn = item.ClientChn
} else {
chn = item.ServerChn
}
return
}
func (s *TunnelBridge) ChnDeamon(item *BridgeItem) {
go func() {
log.Printf("%s conn chan deamon started", item.Key)
for {
var clientConn *net.Conn
var serverConn *net.Conn
serverConn = <-item.ServerChn
log.Printf("%s server conn picked up", item.Key)
OUT:
for {
_item, _ := s.br.Get(item.Key)
Item := _item.(*BridgeItem)
var err error
if Item.ClientControl != nil && *Item.ClientControl != nil {
_, err = (*Item.ClientControl).Write([]byte{'0'})
} else {
err = fmt.Errorf("client control conn not exists")
}
if err != nil {
log.Printf("%s client control conn write signal fail, err: %s, retrying...", item.Key, err)
utils.CloseConn(Item.ClientControl)
*Item.ClientControl = nil
Item.ClientControl = nil
time.Sleep(time.Second * 3)
continue
} else {
select {
case clientConn = <-item.ClientChn:
log.Printf("%s client conn picked up", item.Key)
break OUT
case <-time.After(time.Second * time.Duration(*s.cfg.Timeout*5)):
log.Printf("%s client conn picked timeout, retrying...", item.Key)
}
}
}
utils.IoBind(*serverConn, *clientConn, func(isSrcErr bool, err error) {
utils.CloseConn(serverConn)
utils.CloseConn(clientConn)
log.Printf("%s conn %s - %s - %s - %s released", item.Key, (*serverConn).RemoteAddr(), (*serverConn).LocalAddr(), (*clientConn).LocalAddr(), (*clientConn).RemoteAddr())
}, func(i int, b bool) {}, 0)
log.Printf("%s conn %s - %s - %s - %s created", item.Key, (*serverConn).RemoteAddr(), (*serverConn).LocalAddr(), (*clientConn).LocalAddr(), (*clientConn).RemoteAddr())
}
}()
}

View File

@ -9,6 +9,7 @@ import (
"log"
"net"
"proxy/utils"
"strings"
"time"
)
@ -40,36 +41,39 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
s.cfg = args.(TunnelClientArgs)
s.Check()
s.InitService()
log.Printf("proxy on tunnel client mode")
for {
ctrlConn, err := s.GetInConn(CONN_CONTROL)
ctrlConn, err := s.GetInConn(CONN_CONTROL, "")
if err != nil {
log.Printf("control connection err: %s", err)
time.Sleep(time.Second * 3)
utils.CloseConn(&ctrlConn)
continue
}
if *s.cfg.IsUDP {
log.Printf("proxy on udp tunnel client mode")
} else {
log.Printf("proxy on tcp tunnel client mode")
}
rw := utils.NewHeartbeatReadWriter(&ctrlConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
log.Printf("ctrlConn err %s", err)
utils.CloseConn(&ctrlConn)
})
for {
signal := make([]byte, 1)
if signal[0] == 1 {
continue
}
_, err = ctrlConn.Read(signal)
signal := make([]byte, 50)
n, err := rw.Read(signal)
if err != nil {
utils.CloseConn(&ctrlConn)
log.Printf("read connection signal err: %s", err)
break
}
log.Printf("signal revecived:%s", signal)
if *s.cfg.IsUDP {
go s.ServeUDP()
addr := string(signal[:n])
// log.Printf("n:%d addr:%s err:%s", n, addr, err)
// os.Exit(0)
log.Printf("signal revecived:%s", addr)
protocol := addr[:3]
atIndex := strings.Index(addr, "@")
ID := addr[atIndex+1:]
localAddr := addr[4:atIndex]
if protocol == "udp" {
go s.ServeUDP(localAddr, ID)
} else {
go s.ServeConn()
go s.ServeConn(localAddr, ID)
}
}
}
@ -77,7 +81,7 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
func (s *TunnelClient) Clean() {
s.StopService()
}
func (s *TunnelClient) GetInConn(typ uint8) (outConn net.Conn, err error) {
func (s *TunnelClient) GetInConn(typ uint8, ID string) (outConn net.Conn, err error) {
outConn, err = s.GetConn()
if err != nil {
err = fmt.Errorf("connection err: %s", err)
@ -89,6 +93,12 @@ func (s *TunnelClient) GetInConn(typ uint8) (outConn net.Conn, err error) {
binary.Write(pkg, binary.LittleEndian, typ)
binary.Write(pkg, binary.LittleEndian, keyLength)
binary.Write(pkg, binary.LittleEndian, keyBytes)
if ID != "" {
IDBytes := []byte(ID)
IDLength := uint16(len(IDBytes))
binary.Write(pkg, binary.LittleEndian, IDLength)
binary.Write(pkg, binary.LittleEndian, IDBytes)
}
_, err = outConn.Write(pkg.Bytes())
if err != nil {
err = fmt.Errorf("write connection data err: %s ,retrying...", err)
@ -105,36 +115,41 @@ func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
}
return
}
func (s *TunnelClient) ServeUDP() {
func (s *TunnelClient) ServeUDP(localAddr, ID string) {
var inConn net.Conn
var err error
// for {
for {
for {
inConn, err = s.GetInConn(CONN_CLIENT)
if err != nil {
utils.CloseConn(&inConn)
log.Printf("connection err: %s, retrying...", err)
time.Sleep(time.Second * 3)
continue
} else {
break
}
}
log.Printf("conn created , remote : %s ", inConn.RemoteAddr())
for {
srcAddr, body, err := utils.ReadUDPPacket(&inConn)
if err == io.EOF || err == io.ErrUnexpectedEOF {
log.Printf("connection %s released", srcAddr)
utils.CloseConn(&inConn)
break
}
//log.Printf("udp packet revecived:%s,%v", srcAddr, body)
go s.processUDPPacket(&inConn, srcAddr, body)
inConn, err = s.GetInConn(CONN_CLIENT, ID)
if err != nil {
utils.CloseConn(&inConn)
log.Printf("connection err: %s, retrying...", err)
time.Sleep(time.Second * 3)
continue
} else {
break
}
}
log.Printf("conn %s created", ID)
// hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) {
// log.Printf("hw err %s", err)
// hw.Close()
// })
for {
// srcAddr, body, err := utils.ReadUDPPacket(&hw)
srcAddr, body, err := utils.ReadUDPPacket(inConn)
if err == io.EOF || err == io.ErrUnexpectedEOF {
log.Printf("connection %s released", ID)
utils.CloseConn(&inConn)
break
}
//log.Printf("udp packet revecived:%s,%v", srcAddr, body)
go s.processUDPPacket(&inConn, srcAddr, localAddr, body)
}
// }
}
func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr string, body []byte) {
dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Local)
func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr string, body []byte) {
dstAddr, err := net.ResolveUDPAddr("udp", localAddr)
if err != nil {
log.Printf("can't resolve address: %s", err)
utils.CloseConn(inConn)
@ -169,11 +184,11 @@ func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr string, body [
}
//log.Printf("send udp response success ,from:%s", dstAddr.String())
}
func (s *TunnelClient) ServeConn() {
func (s *TunnelClient) ServeConn(localAddr, ID string) {
var inConn, outConn net.Conn
var err error
for {
inConn, err = s.GetInConn(CONN_CLIENT)
inConn, err = s.GetInConn(CONN_CLIENT, ID)
if err != nil {
utils.CloseConn(&inConn)
log.Printf("connection err: %s, retrying...", err)
@ -187,29 +202,27 @@ func (s *TunnelClient) ServeConn() {
i := 0
for {
i++
outConn, err = utils.ConnectHost(*s.cfg.Local, *s.cfg.Timeout)
outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout)
if err == nil || i == 3 {
break
} else {
if i == 3 {
log.Printf("connect to %s err: %s, retrying...", *s.cfg.Local, err)
log.Printf("connect to %s err: %s, retrying...", localAddr, err)
time.Sleep(2 * time.Second)
continue
}
}
}
if err != nil {
utils.CloseConn(&inConn)
utils.CloseConn(&outConn)
log.Printf("build connection error, err: %s", err)
return
}
utils.IoBind(inConn, outConn, func(isSrcErr bool, err error) {
log.Printf("%s conn %s - %s - %s - %s released", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr())
log.Printf("conn %s released", ID)
utils.CloseConn(&inConn)
utils.CloseConn(&outConn)
}, func(i int, b bool) {}, 0)
log.Printf("%s conn %s - %s - %s - %s created", *s.cfg.Key, inConn.RemoteAddr(), inConn.LocalAddr(), outConn.LocalAddr(), outConn.RemoteAddr())
log.Printf("conn %s created", ID)
}

View File

@ -43,6 +43,9 @@ func (s *TunnelServer) Check() {
} else {
log.Fatalf("parent required")
}
if *s.cfg.Remote == "" {
log.Fatalf("remote required")
}
if s.cfg.CertBytes == nil || s.cfg.KeyBytes == nil {
log.Fatalf("cert and key file required")
}
@ -78,7 +81,7 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
}()
var outConn net.Conn
for {
outConn, err = s.GetOutConn()
outConn, err = s.GetOutConn("")
if err != nil {
utils.CloseConn(&outConn)
log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err)
@ -88,7 +91,6 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
break
}
}
utils.IoBind(inConn, outConn, func(isSrcErr bool, err error) {
utils.CloseConn(&outConn)
utils.CloseConn(&inConn)
@ -107,7 +109,7 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
func (s *TunnelServer) Clean() {
s.StopService()
}
func (s *TunnelServer) GetOutConn() (outConn net.Conn, err error) {
func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, err error) {
outConn, err = s.GetConn()
if err != nil {
log.Printf("connection err: %s", err)
@ -115,10 +117,24 @@ func (s *TunnelServer) GetOutConn() (outConn net.Conn, err error) {
}
keyBytes := []byte(*s.cfg.Key)
keyLength := uint16(len(keyBytes))
IDBytes := []byte(utils.NewUniqueID().String())
if id != "" {
IDBytes = []byte(id)
}
IDLength := uint16(len(IDBytes))
remoteAddr := []byte("tcp:" + *s.cfg.Remote)
if *s.cfg.IsUDP {
remoteAddr = []byte("udp:" + *s.cfg.Remote)
}
remoteAddrLength := uint16(len(remoteAddr))
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, CONN_SERVER)
binary.Write(pkg, binary.LittleEndian, keyLength)
binary.Write(pkg, binary.LittleEndian, keyBytes)
binary.Write(pkg, binary.LittleEndian, IDLength)
binary.Write(pkg, binary.LittleEndian, IDBytes)
binary.Write(pkg, binary.LittleEndian, remoteAddrLength)
binary.Write(pkg, binary.LittleEndian, remoteAddr)
_, err = outConn.Write(pkg.Bytes())
if err != nil {
log.Printf("write connection data err: %s ,retrying...", err)
@ -151,7 +167,7 @@ func (s *TunnelServer) UDPConnDeamon() {
RETRY:
if outConn == nil {
for {
outConn, err = s.GetOutConn()
outConn, err = s.GetOutConn("")
if err != nil {
cmdChn <- true
outConn = nil
@ -166,7 +182,7 @@ func (s *TunnelServer) UDPConnDeamon() {
outConn.Close()
}()
for {
srcAddrFromConn, body, err := utils.ReadUDPPacket(&outConn)
srcAddrFromConn, body, err := utils.ReadUDPPacket(bufio.NewReader(outConn))
if err == io.EOF || err == io.ErrUnexpectedEOF {
log.Printf("udp connection deamon exited, %s -> %s", outConn.LocalAddr(), outConn.RemoteAddr())
break
@ -195,14 +211,17 @@ func (s *TunnelServer) UDPConnDeamon() {
}
}
}
outConn.SetWriteDeadline(time.Now().Add(time.Second))
writer := bufio.NewWriter(outConn)
writer.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet))
err := writer.Flush()
if err != nil {
utils.CloseConn(&outConn)
outConn = nil
log.Printf("write udp packet to %s fail ,flush err:%s", *s.cfg.Parent, err)
log.Printf("write udp packet to %s fail ,flush err:%s ,retrying...", *s.cfg.Parent, err)
goto RETRY
}
outConn.SetWriteDeadline(time.Time{})
//log.Printf("write packet %v", *item.packet)
}
}()

View File

@ -120,7 +120,7 @@ func (s *UDP) OutToTCP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err erro
}()
log.Printf("conn %d created , local: %s", connKey, srcAddr.String())
for {
srcAddrFromConn, body, err := utils.ReadUDPPacket(&conn)
srcAddrFromConn, body, err := utils.ReadUDPPacket(bufio.NewReader(conn))
if err == io.EOF || err == io.ErrUnexpectedEOF {
//log.Printf("connection %d released", connKey)
s.p.Remove(fmt.Sprintf("%d", connKey))

View File

@ -1,7 +1,6 @@
package utils
import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
@ -272,8 +271,8 @@ func UDPPacket(srcAddr string, packet []byte) []byte {
binary.Write(pkg, binary.LittleEndian, packet)
return pkg.Bytes()
}
func ReadUDPPacket(conn *net.Conn) (srcAddr string, packet []byte, err error) {
reader := bufio.NewReader(*conn)
func ReadUDPPacket(reader io.Reader) (srcAddr string, packet []byte, err error) {
// reader := bufio.NewReader(_reader)
var addrLength uint16
var bodyLength uint16
err = binary.Read(reader, binary.LittleEndian, &addrLength)

264
utils/id.go Normal file
View File

@ -0,0 +1,264 @@
// Package xid is a globally unique id generator suited for web scale
//
// Xid is using Mongo Object ID algorithm to generate globally unique ids:
// https://docs.mongodb.org/manual/reference/object-id/
//
// - 4-byte value representing the seconds since the Unix epoch,
// - 3-byte machine identifier,
// - 2-byte process id, and
// - 3-byte counter, starting with a random value.
//
// The binary representation of the id is compatible with Mongo 12 bytes Object IDs.
// The string representation is using base32 hex (w/o padding) for better space efficiency
// when stored in that form (20 bytes). The hex variant of base32 is used to retain the
// sortable property of the id.
//
// Xid doesn't use base64 because case sensitivity and the 2 non alphanum chars may be an
// issue when transported as a string between various systems. Base36 wasn't retained either
// because 1/ it's not standard 2/ the resulting size is not predictable (not bit aligned)
// and 3/ it would not remain sortable. To validate a base32 `xid`, expect a 20 chars long,
// all lowercase sequence of `a` to `v` letters and `0` to `9` numbers (`[0-9a-v]{20}`).
//
// UUID is 16 bytes (128 bits), snowflake is 8 bytes (64 bits), xid stands in between
// with 12 bytes with a more compact string representation ready for the web and no
// required configuration or central generation server.
//
// Features:
//
// - Size: 12 bytes (96 bits), smaller than UUID, larger than snowflake
// - Base32 hex encoded by default (16 bytes storage when transported as printable string)
// - Non configured, you don't need set a unique machine and/or data center id
// - K-ordered
// - Embedded time with 1 second precision
// - Unicity guaranted for 16,777,216 (24 bits) unique ids per second and per host/process
//
// Best used with xlog's RequestIDHandler (https://godoc.org/github.com/rs/xlog#RequestIDHandler).
//
// References:
//
// - http://www.slideshare.net/davegardnerisme/unique-id-generation-in-distributed-systems
// - https://en.wikipedia.org/wiki/Universally_unique_identifier
// - https://blog.twitter.com/2010/announcing-snowflake
package utils
import (
"crypto/md5"
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"os"
"sync/atomic"
"time"
)
// Code inspired from mgo/bson ObjectId
// ID represents a unique request id
type ID [rawLen]byte
const (
encodedLen = 20 // string encoded len
decodedLen = 15 // len after base32 decoding with the padded data
rawLen = 12 // binary raw len
// encoding stores a custom version of the base32 encoding with lower case
// letters.
encoding = "0123456789abcdefghijklmnopqrstuv"
)
// ErrInvalidID is returned when trying to unmarshal an invalid ID
var ErrInvalidID = errors.New("xid: invalid ID")
// objectIDCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
// This id is initialized with a random value.
var objectIDCounter = randInt()
// machineId stores machine id generated once and used in subsequent calls
// to NewObjectId function.
var machineID = readMachineID()
// pid stores the current process id
var pid = os.Getpid()
// dec is the decoding map for base32 encoding
var dec [256]byte
func init() {
for i := 0; i < len(dec); i++ {
dec[i] = 0xFF
}
for i := 0; i < len(encoding); i++ {
dec[encoding[i]] = byte(i)
}
}
// readMachineId generates machine id and puts it into the machineId global
// variable. If this function fails to get the hostname, it will cause
// a runtime error.
func readMachineID() []byte {
id := make([]byte, 3)
if hostname, err := os.Hostname(); err == nil {
hw := md5.New()
hw.Write([]byte(hostname))
copy(id, hw.Sum(nil))
} else {
// Fallback to rand number if machine id can't be gathered
if _, randErr := rand.Reader.Read(id); randErr != nil {
panic(fmt.Errorf("xid: cannot get hostname nor generate a random number: %v; %v", err, randErr))
}
}
return id
}
// randInt generates a random uint32
func randInt() uint32 {
b := make([]byte, 3)
if _, err := rand.Reader.Read(b); err != nil {
panic(fmt.Errorf("xid: cannot generate random number: %v;", err))
}
return uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])
}
// NewUniqueID generates a globaly unique ID
func NewUniqueID() ID {
var id ID
// Timestamp, 4 bytes, big endian
binary.BigEndian.PutUint32(id[:], uint32(time.Now().Unix()))
// Machine, first 3 bytes of md5(hostname)
id[4] = machineID[0]
id[5] = machineID[1]
id[6] = machineID[2]
// Pid, 2 bytes, specs don't specify endianness, but we use big endian.
id[7] = byte(pid >> 8)
id[8] = byte(pid)
// Increment, 3 bytes, big endian
i := atomic.AddUint32(&objectIDCounter, 1)
id[9] = byte(i >> 16)
id[10] = byte(i >> 8)
id[11] = byte(i)
return id
}
// FromString reads an ID from its string representation
func FromString(id string) (ID, error) {
i := &ID{}
err := i.UnmarshalText([]byte(id))
return *i, err
}
// String returns a base32 hex lowercased with no padding representation of the id (char set is 0-9, a-v).
func (id ID) String() string {
text := make([]byte, encodedLen)
encode(text, id[:])
return string(text)
}
// MarshalText implements encoding/text TextMarshaler interface
func (id ID) MarshalText() ([]byte, error) {
text := make([]byte, encodedLen)
encode(text, id[:])
return text, nil
}
// encode by unrolling the stdlib base32 algorithm + removing all safe checks
func encode(dst, id []byte) {
dst[0] = encoding[id[0]>>3]
dst[1] = encoding[(id[1]>>6)&0x1F|(id[0]<<2)&0x1F]
dst[2] = encoding[(id[1]>>1)&0x1F]
dst[3] = encoding[(id[2]>>4)&0x1F|(id[1]<<4)&0x1F]
dst[4] = encoding[id[3]>>7|(id[2]<<1)&0x1F]
dst[5] = encoding[(id[3]>>2)&0x1F]
dst[6] = encoding[id[4]>>5|(id[3]<<3)&0x1F]
dst[7] = encoding[id[4]&0x1F]
dst[8] = encoding[id[5]>>3]
dst[9] = encoding[(id[6]>>6)&0x1F|(id[5]<<2)&0x1F]
dst[10] = encoding[(id[6]>>1)&0x1F]
dst[11] = encoding[(id[7]>>4)&0x1F|(id[6]<<4)&0x1F]
dst[12] = encoding[id[8]>>7|(id[7]<<1)&0x1F]
dst[13] = encoding[(id[8]>>2)&0x1F]
dst[14] = encoding[(id[9]>>5)|(id[8]<<3)&0x1F]
dst[15] = encoding[id[9]&0x1F]
dst[16] = encoding[id[10]>>3]
dst[17] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
dst[18] = encoding[(id[11]>>1)&0x1F]
dst[19] = encoding[(id[11]<<4)&0x1F]
}
// UnmarshalText implements encoding/text TextUnmarshaler interface
func (id *ID) UnmarshalText(text []byte) error {
if len(text) != encodedLen {
return ErrInvalidID
}
for _, c := range text {
if dec[c] == 0xFF {
return ErrInvalidID
}
}
decode(id, text)
return nil
}
// decode by unrolling the stdlib base32 algorithm + removing all safe checks
func decode(id *ID, src []byte) {
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[3] = dec[src[4]]<<7 | dec[src[5]]<<2 | dec[src[6]]>>3
id[4] = dec[src[6]]<<5 | dec[src[7]]
id[5] = dec[src[8]]<<3 | dec[src[9]]>>2
id[6] = dec[src[9]]<<6 | dec[src[10]]<<1 | dec[src[11]]>>4
id[7] = dec[src[11]]<<4 | dec[src[12]]>>1
id[8] = dec[src[12]]<<7 | dec[src[13]]<<2 | dec[src[14]]>>3
id[9] = dec[src[14]]<<5 | dec[src[15]]
id[10] = dec[src[16]]<<3 | dec[src[17]]>>2
id[11] = dec[src[17]]<<6 | dec[src[18]]<<1 | dec[src[19]]>>4
}
// Time returns the timestamp part of the id.
// It's a runtime error to call this method with an invalid id.
func (id ID) Time() time.Time {
// First 4 bytes of ObjectId is 32-bit big-endian seconds from epoch.
secs := int64(binary.BigEndian.Uint32(id[0:4]))
return time.Unix(secs, 0)
}
// Machine returns the 3-byte machine id part of the id.
// It's a runtime error to call this method with an invalid id.
func (id ID) Machine() []byte {
return id[4:7]
}
// Pid returns the process id part of the id.
// It's a runtime error to call this method with an invalid id.
func (id ID) Pid() uint16 {
return binary.BigEndian.Uint16(id[7:9])
}
// Counter returns the incrementing value part of the id.
// It's a runtime error to call this method with an invalid id.
func (id ID) Counter() int32 {
b := id[9:12]
// Counter is stored as big-endian 3-byte value
return int32(uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2]))
}
// Value implements the driver.Valuer interface.
func (id ID) Value() (driver.Value, error) {
b, err := id.MarshalText()
return string(b), err
}
// Scan implements the sql.Scanner interface.
func (id *ID) Scan(value interface{}) (err error) {
switch val := value.(type) {
case string:
return id.UnmarshalText([]byte(val))
case []byte:
return id.UnmarshalText(val)
default:
return fmt.Errorf("xid: scanning unsupported type: %T", value)
}
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
@ -11,6 +12,7 @@ import (
"net"
"net/url"
"strings"
"sync"
"time"
)
@ -460,3 +462,129 @@ func (op *OutPool) initPoolDeamon() {
}
}()
}
type HeartbeatData struct {
Data []byte
N int
Error error
}
type HeartbeatReadWriter struct {
conn *net.Conn
rchn chan HeartbeatData
l *sync.Mutex
dur int
errHandler func(err error, hb *HeartbeatReadWriter)
once *sync.Once
}
func NewHeartbeatReadWriter(conn *net.Conn, dur int, fn func(err error, hb *HeartbeatReadWriter)) (hrw HeartbeatReadWriter) {
hrw = HeartbeatReadWriter{
conn: conn,
l: &sync.Mutex{},
dur: dur,
rchn: make(chan HeartbeatData, 10000),
errHandler: fn,
once: &sync.Once{},
}
hrw.heartbeat()
hrw.reader()
return
}
func (rw *HeartbeatReadWriter) Close() {
CloseConn(rw.conn)
}
func (rw *HeartbeatReadWriter) reader() {
go func() {
//log.Printf("heartbeat read started")
for {
n, data, err := rw.read()
log.Printf("n:%d , data:%s ,err:%s", n, string(data), err)
if n >= 0 {
rw.rchn <- HeartbeatData{
Data: data,
Error: err,
N: n,
}
}
if err != nil {
break
}
}
//log.Printf("heartbeat read exited")
}()
}
func (rw *HeartbeatReadWriter) read() (n int, data []byte, err error) {
defer func() {
if err != nil {
rw.once.Do(func() {
rw.errHandler(err, rw)
})
}
}()
var typ uint8
err = binary.Read((*rw.conn), binary.LittleEndian, &typ)
if err != nil {
return
}
if typ == 0 {
// log.Printf("heartbeat revecived")
n = -1
return
}
var dataLength uint32
binary.Read((*rw.conn), binary.LittleEndian, &dataLength)
data = make([]byte, dataLength)
// log.Printf("dataLength:%d , data:%s", dataLength, string(data))
n, err = (*rw.conn).Read(data)
//log.Printf("n:%d , data:%s ,err:%s", n, string(data), err)
if err != nil {
return
}
return
}
func (rw *HeartbeatReadWriter) heartbeat() {
go func() {
//log.Printf("heartbeat started")
for {
if rw.conn == nil || *rw.conn == nil {
//log.Printf("heartbeat err: conn nil")
break
}
rw.l.Lock()
_, err := (*rw.conn).Write([]byte{0})
rw.l.Unlock()
if err != nil {
if rw.errHandler != nil {
//log.Printf("heartbeat err: %s", err)
rw.once.Do(func() {
rw.errHandler(err, rw)
})
break
}
} else {
// log.Printf("heartbeat send ok")
}
time.Sleep(time.Second * time.Duration(rw.dur))
}
//log.Printf("heartbeat exited")
}()
}
func (rw *HeartbeatReadWriter) Read(p []byte) (n int, err error) {
item := <-rw.rchn
//log.Println(item)
if item.N > 0 {
copy(p, item.Data)
}
return item.N, item.Error
}
func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) {
defer rw.l.Unlock()
rw.l.Lock()
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, uint8(1))
binary.Write(pkg, binary.LittleEndian, uint32(len(p)))
binary.Write(pkg, binary.LittleEndian, p)
n, err = (*rw.conn).Write(pkg.Bytes())
return
}