Merge branch 'master' into dev

This commit is contained in:
arraykeys@gmail.com
2017-10-23 16:35:23 +08:00
7 changed files with 366 additions and 152 deletions

View File

@ -6,7 +6,7 @@ fi
mkdir /tmp/proxy mkdir /tmp/proxy
cd /tmp/proxy cd /tmp/proxy
wget https://github.com/reddec/monexec/releases/download/v0.1.1/monexec_0.1.1_linux_amd64.tar.gz wget https://github.com/reddec/monexec/releases/download/v0.1.1/monexec_0.1.1_linux_amd64.tar.gz
wget https://github.com/snail007/goproxy/releases/download/v3.4/proxy-linux-amd64.tar.gz wget https://github.com/snail007/goproxy/releases/download/v3.3/proxy-linux-amd64.tar.gz
# install monexec # install monexec
tar zxvf monexec_0.1.1_linux_amd64.tar.gz tar zxvf monexec_0.1.1_linux_amd64.tar.gz

View File

@ -6,13 +6,14 @@ import "golang.org/x/crypto/ssh"
// t := tcp.Flag("tcp-timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int() // t := tcp.Flag("tcp-timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int()
const ( const (
TYPE_TCP = "tcp" TYPE_TCP = "tcp"
TYPE_UDP = "udp" TYPE_UDP = "udp"
TYPE_HTTP = "http" TYPE_HTTP = "http"
TYPE_TLS = "tls" TYPE_TLS = "tls"
CONN_CONTROL = uint8(1) CONN_CLIENT_CONTROL = uint8(1)
CONN_SERVER = uint8(2) CONN_SERVER_CONTROL = uint8(2)
CONN_CLIENT = uint8(3) CONN_SERVER = uint8(3)
CONN_CLIENT = uint8(4)
) )
type TunnelServerArgs struct { type TunnelServerArgs struct {
@ -27,6 +28,7 @@ type TunnelServerArgs struct {
Remote *string Remote *string
Timeout *int Timeout *int
Route *[]string Route *[]string
Mgr *TunnelServerManager
} }
type TunnelClientArgs struct { type TunnelClientArgs struct {
Parent *string Parent *string

View File

@ -2,7 +2,6 @@ package services
import ( import (
"bufio" "bufio"
"encoding/binary"
"log" "log"
"net" "net"
"proxy/utils" "proxy/utils"
@ -11,13 +10,15 @@ import (
) )
type ServerConn struct { type ServerConn struct {
ClientLocalAddr string //tcp:2.2.22:333@ID //ClientLocalAddr string //tcp:2.2.22:333@ID
Conn *net.Conn Conn *net.Conn
} }
type TunnelBridge struct { type TunnelBridge struct {
cfg TunnelBridgeArgs cfg TunnelBridgeArgs
serverConns utils.ConcurrentMap serverConns utils.ConcurrentMap
clientControlConns utils.ConcurrentMap clientControlConns utils.ConcurrentMap
cmServer utils.ConnManager
cmClient utils.ConnManager
} }
func NewTunnelBridge() Service { func NewTunnelBridge() Service {
@ -25,6 +26,8 @@ func NewTunnelBridge() Service {
cfg: TunnelBridgeArgs{}, cfg: TunnelBridgeArgs{},
serverConns: utils.NewConcurrentMap(), serverConns: utils.NewConcurrentMap(),
clientControlConns: utils.NewConcurrentMap(), clientControlConns: utils.NewConcurrentMap(),
cmServer: utils.NewConnManager(),
cmClient: utils.NewConnManager(),
} }
} }
@ -52,73 +55,27 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
//log.Printf("connection from %s ", inConn.RemoteAddr()) //log.Printf("connection from %s ", inConn.RemoteAddr())
reader := bufio.NewReader(inConn) reader := bufio.NewReader(inConn)
var err error
var connType uint8 var connType uint8
err = binary.Read(reader, binary.LittleEndian, &connType) err = utils.ReadPacket(reader, &connType)
if err != nil { if err != nil {
utils.CloseConn(&inConn) log.Printf("read error,ERR:%s", err)
return return
} }
//log.Printf("conn type %d", connType)
var key, clientLocalAddr, ID string
var connTypeStrMap = map[uint8]string{CONN_SERVER: "server", CONN_CLIENT: "client", CONN_CONTROL: "client"}
var keyLength uint16
err = binary.Read(reader, binary.LittleEndian, &keyLength)
if err != nil {
return
}
_key := make([]byte, keyLength)
n, err := reader.Read(_key)
if err != nil {
return
}
if n != int(keyLength) {
return
}
key = string(_key)
if connType != CONN_CONTROL {
var IDLength uint16
err = binary.Read(reader, binary.LittleEndian, &IDLength)
if err != nil {
return
}
_id := make([]byte, IDLength)
n, err := reader.Read(_id)
if err != nil {
return
}
if n != int(IDLength) {
return
}
ID = string(_id)
if connType == CONN_SERVER {
var addrLength uint16
err = binary.Read(reader, binary.LittleEndian, &addrLength)
if err != nil {
return
}
_addr := make([]byte, addrLength)
n, err = reader.Read(_addr)
if err != nil {
return
}
if n != int(addrLength) {
return
}
clientLocalAddr = string(_addr)
}
}
log.Printf("connection from %s , key: %s , id: %s", connTypeStrMap[connType], key, ID)
switch connType { switch connType {
case CONN_SERVER: case CONN_SERVER:
addr := clientLocalAddr + "@" + ID var key, ID, clientLocalAddr, serverID string
err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
packet := utils.BuildPacketData(ID, clientLocalAddr, serverID)
log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID)
//addr := clientLocalAddr + "@" + ID
s.serverConns.Set(ID, ServerConn{ s.serverConns.Set(ID, ServerConn{
Conn: &inConn, Conn: &inConn,
ClientLocalAddr: addr,
}) })
for { for {
item, ok := s.clientControlConns.Get(key) item, ok := s.clientControlConns.Get(key)
@ -128,17 +85,26 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
continue continue
} }
(*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3)) (*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err := (*item.(*net.Conn)).Write([]byte(addr)) _, err := (*item.(*net.Conn)).Write(packet)
(*item.(*net.Conn)).SetWriteDeadline(time.Time{}) (*item.(*net.Conn)).SetWriteDeadline(time.Time{})
if err != nil { if err != nil {
log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err) log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} else { } else {
s.cmServer.Add(serverID, ID, &inConn)
break break
} }
} }
case CONN_CLIENT: case CONN_CLIENT:
var key, ID, serverID string
err = utils.ReadPacketData(reader, &key, &ID, &serverID)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID)
serverConnItem, ok := s.serverConns.Get(ID) serverConnItem, ok := s.serverConns.Get(ID)
if !ok { if !ok {
inConn.Close() inConn.Close()
@ -147,15 +113,24 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
} }
serverConn := serverConnItem.(ServerConn).Conn serverConn := serverConnItem.(ServerConn).Conn
utils.IoBind(*serverConn, inConn, func(err error) { utils.IoBind(*serverConn, inConn, func(err error) {
(*serverConn).Close() (*serverConn).Close()
utils.CloseConn(&inConn) utils.CloseConn(&inConn)
s.serverConns.Remove(ID) s.serverConns.Remove(ID)
s.cmClient.RemoveOne(key, ID)
s.cmServer.RemoveOne(serverID, ID)
log.Printf("conn %s released", ID) log.Printf("conn %s released", ID)
}, func(i int, b bool) {}, 0) }, func(i int, b bool) {}, 0)
s.cmClient.Add(key, ID, &inConn)
log.Printf("conn %s created", ID) log.Printf("conn %s created", ID)
case CONN_CONTROL: case CONN_CLIENT_CONTROL:
var key string
err = utils.ReadPacketData(reader, &key)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
log.Printf("client control connection, key: %s", key)
if s.clientControlConns.Has(key) { if s.clientControlConns.Has(key) {
item, _ := s.clientControlConns.Get(key) item, _ := s.clientControlConns.Get(key)
(*item.(*net.Conn)).Close() (*item.(*net.Conn)).Close()
@ -168,14 +143,59 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
_, err = inConn.Read(b) _, err = inConn.Read(b)
if err != nil { if err != nil {
inConn.Close() inConn.Close()
s.serverConns.Remove(ID)
log.Printf("%s control conn from client released", key) log.Printf("%s control conn from client released", key)
s.cmClient.Remove(key)
break break
} else { } else {
//log.Printf("%s heartbeat from client", key) //log.Printf("%s heartbeat from client", key)
} }
} }
}() }()
case CONN_SERVER_CONTROL:
var serverID string
err = utils.ReadPacketData(reader, &serverID)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
log.Printf("server control connection, id: %s", serverID)
writeDie := make(chan bool)
readDie := make(chan bool)
go func() {
for {
inConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err = inConn.Write([]byte{0x00})
inConn.SetWriteDeadline(time.Time{})
if err != nil {
log.Printf("control connection write err %s", err)
break
}
time.Sleep(time.Second * 3)
}
close(writeDie)
}()
go func() {
for {
signal := make([]byte, 1)
inConn.SetReadDeadline(time.Now().Add(time.Second * 10))
_, err := inConn.Read(signal)
inConn.SetReadDeadline(time.Time{})
if err != nil {
log.Printf("control connection read err: %s", err)
break
} else {
// log.Printf("heartbeat from server ,id:%s", ID)
}
}
close(readDie)
}()
select {
case <-readDie:
case <-writeDie:
}
utils.CloseConn(&inConn)
s.cmServer.Remove(serverID)
log.Printf("server control conn %s released", serverID)
} }
}) })
if err != nil { if err != nil {

View File

@ -1,25 +1,24 @@
package services package services
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"proxy/utils" "proxy/utils"
"strings"
"time" "time"
) )
type TunnelClient struct { type TunnelClient struct {
cfg TunnelClientArgs cfg TunnelClientArgs
cm utils.ConnManager
} }
func NewTunnelClient() Service { func NewTunnelClient() Service {
return &TunnelClient{ return &TunnelClient{
cfg: TunnelClientArgs{}, cfg: TunnelClientArgs{},
cm: utils.NewConnManager(),
} }
} }
@ -37,14 +36,20 @@ func (s *TunnelClient) CheckArgs() {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile) s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
} }
func (s *TunnelClient) StopService() { func (s *TunnelClient) StopService() {
s.cm.RemoveAll()
} }
func (s *TunnelClient) Start(args interface{}) (err error) { func (s *TunnelClient) Start(args interface{}) (err error) {
s.cfg = args.(TunnelClientArgs) s.cfg = args.(TunnelClientArgs)
s.CheckArgs() s.CheckArgs()
s.InitService() s.InitService()
log.Printf("proxy on tunnel client mode") log.Printf("proxy on tunnel client mode")
var ctrlConn net.Conn
for { for {
ctrlConn, err := s.GetInConn(CONN_CONTROL, "") //close all conn
s.cm.Remove(*s.cfg.Key)
utils.CloseConn(&ctrlConn)
ctrlConn, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key)
if err != nil { if err != nil {
log.Printf("control connection err: %s, retrying...", err) log.Printf("control connection err: %s, retrying...", err)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@ -53,6 +58,9 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
} }
go func() { go func() {
for { for {
if ctrlConn == nil {
break
}
ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3)) ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err = ctrlConn.Write([]byte{0x00}) _, err = ctrlConn.Write([]byte{0x00})
ctrlConn.SetWriteDeadline(time.Time{}) ctrlConn.SetWriteDeadline(time.Time{})
@ -65,23 +73,20 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
} }
}() }()
for { for {
signal := make([]byte, 50) var ID, clientLocalAddr, serverID string
n, err := ctrlConn.Read(signal) err = utils.ReadPacketData(ctrlConn, &ID, &clientLocalAddr, &serverID)
if err != nil { if err != nil {
utils.CloseConn(&ctrlConn) utils.CloseConn(&ctrlConn)
log.Printf("read connection signal err: %s, retrying...", err) log.Printf("read connection signal err: %s, retrying...", err)
break break
} }
addr := string(signal[:n]) log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr)
log.Printf("signal revecived:%s", addr) protocol := clientLocalAddr[:3]
protocol := addr[:3] localAddr := clientLocalAddr[4:]
atIndex := strings.Index(addr, "@")
ID := addr[atIndex+1:]
localAddr := addr[4:atIndex]
if protocol == "udp" { if protocol == "udp" {
go s.ServeUDP(localAddr, ID) go s.ServeUDP(localAddr, ID, serverID)
} else { } else {
go s.ServeConn(localAddr, ID) go s.ServeConn(localAddr, ID, serverID)
} }
} }
} }
@ -89,25 +94,13 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
func (s *TunnelClient) Clean() { func (s *TunnelClient) Clean() {
s.StopService() s.StopService()
} }
func (s *TunnelClient) GetInConn(typ uint8, ID string) (outConn net.Conn, err error) { func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, err error) {
outConn, err = s.GetConn() outConn, err = s.GetConn()
if err != nil { if err != nil {
err = fmt.Errorf("connection err: %s", err) err = fmt.Errorf("connection err: %s", err)
return return
} }
keyBytes := []byte(*s.cfg.Key) _, err = outConn.Write(utils.BuildPacket(typ, data...))
keyLength := uint16(len(keyBytes))
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, typ)
binary.Write(pkg, binary.LittleEndian, keyLength)
binary.Write(pkg, binary.LittleEndian, keyBytes)
if ID != "" {
IDBytes := []byte(ID)
IDLength := uint16(len(IDBytes))
binary.Write(pkg, binary.LittleEndian, IDLength)
binary.Write(pkg, binary.LittleEndian, IDBytes)
}
_, err = outConn.Write(pkg.Bytes())
if err != nil { if err != nil {
err = fmt.Errorf("write connection data err: %s ,retrying...", err) err = fmt.Errorf("write connection data err: %s ,retrying...", err)
utils.CloseConn(&outConn) utils.CloseConn(&outConn)
@ -123,12 +116,13 @@ func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
} }
return return
} }
func (s *TunnelClient) ServeUDP(localAddr, ID string) { func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) {
var inConn net.Conn var inConn net.Conn
var err error var err error
// for { // for {
for { for {
inConn, err = s.GetInConn(CONN_CLIENT, ID) s.cm.RemoveOne(*s.cfg.Key, ID)
inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
if err != nil { if err != nil {
utils.CloseConn(&inConn) utils.CloseConn(&inConn)
log.Printf("connection err: %s, retrying...", err) log.Printf("connection err: %s, retrying...", err)
@ -138,13 +132,10 @@ func (s *TunnelClient) ServeUDP(localAddr, ID string) {
break break
} }
} }
s.cm.Add(*s.cfg.Key, ID, &inConn)
log.Printf("conn %s created", ID) log.Printf("conn %s created", ID)
// hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) {
// log.Printf("hw err %s", err)
// hw.Close()
// })
for { for {
// srcAddr, body, err := utils.ReadUDPPacket(&hw)
srcAddr, body, err := utils.ReadUDPPacket(inConn) srcAddr, body, err := utils.ReadUDPPacket(inConn)
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
log.Printf("connection %s released", ID) log.Printf("connection %s released", ID)
@ -197,11 +188,11 @@ func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr str
} }
//log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs) //log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs)
} }
func (s *TunnelClient) ServeConn(localAddr, ID string) { func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) {
var inConn, outConn net.Conn var inConn, outConn net.Conn
var err error var err error
for { for {
inConn, err = s.GetInConn(CONN_CLIENT, ID) inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
if err != nil { if err != nil {
utils.CloseConn(&inConn) utils.CloseConn(&inConn)
log.Printf("connection err: %s, retrying...", err) log.Printf("connection err: %s, retrying...", err)
@ -236,6 +227,8 @@ func (s *TunnelClient) ServeConn(localAddr, ID string) {
log.Printf("conn %s released", ID) log.Printf("conn %s released", ID)
utils.CloseConn(&inConn) utils.CloseConn(&inConn)
utils.CloseConn(&outConn) utils.CloseConn(&outConn)
s.cm.RemoveOne(*s.cfg.Key, ID)
}, func(i int, b bool) {}, 0) }, func(i int, b bool) {}, 0)
s.cm.Add(*s.cfg.Key, ID, &inConn)
log.Printf("conn %s created", ID) log.Printf("conn %s created", ID)
} }

View File

@ -1,9 +1,7 @@
package services package services
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -22,24 +20,33 @@ type TunnelServer struct {
} }
type TunnelServerManager struct { type TunnelServerManager struct {
cfg TunnelServerArgs cfg TunnelServerArgs
udpChn chan UDPItem udpChn chan UDPItem
sc utils.ServerChannel sc utils.ServerChannel
serverID string
cm utils.ConnManager
} }
func NewTunnelServerManager() Service { func NewTunnelServerManager() Service {
return &TunnelServerManager{ return &TunnelServerManager{
cfg: TunnelServerArgs{}, cfg: TunnelServerArgs{},
udpChn: make(chan UDPItem, 50000), udpChn: make(chan UDPItem, 50000),
serverID: utils.Uniqueid(),
cm: utils.NewConnManager(),
} }
} }
func (s *TunnelServerManager) Start(args interface{}) (err error) { func (s *TunnelServerManager) Start(args interface{}) (err error) {
s.cfg = args.(TunnelServerArgs) s.cfg = args.(TunnelServerArgs)
s.CheckArgs()
if *s.cfg.Parent != "" { if *s.cfg.Parent != "" {
log.Printf("use tls parent %s", *s.cfg.Parent) log.Printf("use tls parent %s", *s.cfg.Parent)
} else { } else {
log.Fatalf("parent required") log.Fatalf("parent required")
} }
s.InitService()
log.Printf("server id: %s", s.serverID)
//log.Printf("route:%v", *s.cfg.Route) //log.Printf("route:%v", *s.cfg.Route)
for _, _info := range *s.cfg.Route { for _, _info := range *s.cfg.Route {
IsUDP := *s.cfg.IsUDP IsUDP := *s.cfg.IsUDP
@ -71,6 +78,7 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) {
Remote: &remote, Remote: &remote,
Key: &KEY, Key: &KEY,
Timeout: s.cfg.Timeout, Timeout: s.cfg.Timeout,
Mgr: s,
}) })
if err != nil { if err != nil {
@ -80,7 +88,95 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) {
return return
} }
func (s *TunnelServerManager) Clean() { func (s *TunnelServerManager) Clean() {
s.StopService()
}
func (s *TunnelServerManager) StopService() {
s.cm.RemoveAll()
}
func (s *TunnelServerManager) CheckArgs() {
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
log.Fatalf("cert and key file required")
}
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
}
func (s *TunnelServerManager) InitService() {
s.InitControlDeamon()
}
func (s *TunnelServerManager) InitControlDeamon() {
go func() {
var ctrlConn net.Conn
var ID string
for {
//close all connection
s.cm.Remove(ID)
utils.CloseConn(&ctrlConn)
ctrlConn, ID, err := s.GetOutConn(CONN_SERVER_CONTROL)
if err != nil {
log.Printf("control connection err: %s, retrying...", err)
time.Sleep(time.Second * 3)
utils.CloseConn(&ctrlConn)
continue
}
log.Printf("control connection created,id:%s", ID)
writeDie := make(chan bool)
readDie := make(chan bool)
go func() {
for {
ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err = ctrlConn.Write([]byte{0x00})
ctrlConn.SetWriteDeadline(time.Time{})
if err != nil {
log.Printf("control connection write err %s", err)
break
}
time.Sleep(time.Second * 3)
}
close(writeDie)
}()
go func() {
for {
signal := make([]byte, 1)
ctrlConn.SetReadDeadline(time.Now().Add(time.Second * 10))
_, err := ctrlConn.Read(signal)
ctrlConn.SetReadDeadline(time.Time{})
if err != nil {
log.Printf("control connection read err: %s", err)
break
} else {
// log.Printf("heartbeat from bridge")
}
}
close(readDie)
}()
select {
case <-readDie:
case <-writeDie:
}
}
}()
}
func (s *TunnelServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) {
outConn, err = s.GetConn()
if err != nil {
log.Printf("connection err: %s", err)
return
}
ID = s.serverID
_, err = outConn.Write(utils.BuildPacket(typ, s.serverID))
if err != nil {
log.Printf("write connection data err: %s ,retrying...", err)
utils.CloseConn(&outConn)
return
}
return
}
func (s *TunnelServerManager) GetConn() (conn net.Conn, err error) {
var _conn tls.Conn
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
if err == nil {
conn = net.Conn(&_conn)
}
return
} }
func NewTunnelServer() Service { func NewTunnelServer() Service {
return &TunnelServer{ return &TunnelServer{
@ -102,13 +198,8 @@ func (s *TunnelServer) CheckArgs() {
if *s.cfg.Remote == "" { if *s.cfg.Remote == "" {
log.Fatalf("remote required") log.Fatalf("remote required")
} }
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
log.Fatalf("cert and key file required")
}
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
}
func (s *TunnelServer) StopService() {
} }
func (s *TunnelServer) Start(args interface{}) (err error) { func (s *TunnelServer) Start(args interface{}) (err error) {
s.cfg = args.(TunnelServerArgs) s.cfg = args.(TunnelServerArgs)
s.CheckArgs() s.CheckArgs()
@ -138,7 +229,7 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
var outConn net.Conn var outConn net.Conn
var ID string var ID string
for { for {
outConn, ID, err = s.GetOutConn("") outConn, ID, err = s.GetOutConn(CONN_SERVER)
if err != nil { if err != nil {
utils.CloseConn(&outConn) utils.CloseConn(&outConn)
log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err) log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err)
@ -151,9 +242,11 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
utils.IoBind(inConn, outConn, func(err error) { utils.IoBind(inConn, outConn, func(err error) {
utils.CloseConn(&outConn) utils.CloseConn(&outConn)
utils.CloseConn(&inConn) utils.CloseConn(&inConn)
s.cfg.Mgr.cm.RemoveOne(s.cfg.Mgr.serverID, ID)
log.Printf("%s conn %s released", *s.cfg.Key, ID) log.Printf("%s conn %s released", *s.cfg.Key, ID)
}, func(i int, b bool) {}, 0) }, func(i int, b bool) {}, 0)
//add conn
s.cfg.Mgr.cm.Add(s.cfg.Mgr.serverID, ID, &inConn)
log.Printf("%s conn %s created", *s.cfg.Key, ID) log.Printf("%s conn %s created", *s.cfg.Key, ID)
}) })
if err != nil { if err != nil {
@ -164,37 +257,20 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
return return
} }
func (s *TunnelServer) Clean() { func (s *TunnelServer) Clean() {
s.StopService()
} }
func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, ID string, err error) { func (s *TunnelServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) {
outConn, err = s.GetConn() outConn, err = s.GetConn()
if err != nil { if err != nil {
log.Printf("connection err: %s", err) log.Printf("connection err: %s", err)
return return
} }
keyBytes := []byte(*s.cfg.Key) remoteAddr := "tcp:" + *s.cfg.Remote
keyLength := uint16(len(keyBytes))
ID = utils.Uniqueid()
IDBytes := []byte(ID)
if id != "" {
ID = id
IDBytes = []byte(id)
}
IDLength := uint16(len(IDBytes))
remoteAddr := []byte("tcp:" + *s.cfg.Remote)
if *s.cfg.IsUDP { if *s.cfg.IsUDP {
remoteAddr = []byte("udp:" + *s.cfg.Remote) remoteAddr = "udp:" + *s.cfg.Remote
} }
remoteAddrLength := uint16(len(remoteAddr)) ID = utils.Uniqueid()
pkg := new(bytes.Buffer) _, err = outConn.Write(utils.BuildPacket(typ, *s.cfg.Key, ID, remoteAddr, s.cfg.Mgr.serverID))
binary.Write(pkg, binary.LittleEndian, CONN_SERVER)
binary.Write(pkg, binary.LittleEndian, keyLength)
binary.Write(pkg, binary.LittleEndian, keyBytes)
binary.Write(pkg, binary.LittleEndian, IDLength)
binary.Write(pkg, binary.LittleEndian, IDBytes)
binary.Write(pkg, binary.LittleEndian, remoteAddrLength)
binary.Write(pkg, binary.LittleEndian, remoteAddr)
_, err = outConn.Write(pkg.Bytes())
if err != nil { if err != nil {
log.Printf("write connection data err: %s ,retrying...", err) log.Printf("write connection data err: %s ,retrying...", err)
utils.CloseConn(&outConn) utils.CloseConn(&outConn)
@ -227,7 +303,7 @@ func (s *TunnelServer) UDPConnDeamon() {
RETRY: RETRY:
if outConn == nil { if outConn == nil {
for { for {
outConn, ID, err = s.GetOutConn("") outConn, ID, err = s.GetOutConn(CONN_SERVER)
if err != nil { if err != nil {
// cmdChn <- true // cmdChn <- true
outConn = nil outConn = nil

View File

@ -315,6 +315,68 @@ func Uniqueid() string {
s := fmt.Sprintf("%d", src.Int63()) s := fmt.Sprintf("%d", src.Int63())
return s[len(s)-5:len(s)-1] + fmt.Sprintf("%d", uint64(time.Now().UnixNano()))[8:] return s[len(s)-5:len(s)-1] + fmt.Sprintf("%d", uint64(time.Now().UnixNano()))[8:]
} }
func ReadData(r io.Reader) (data string, err error) {
var len uint16
err = binary.Read(r, binary.LittleEndian, &len)
if err != nil {
return
}
var n int
_data := make([]byte, len)
n, err = r.Read(_data)
if err != nil {
return
}
if n != int(len) {
err = fmt.Errorf("error data len")
return
}
data = string(_data)
return
}
func ReadPacketData(r io.Reader, data ...*string) (err error) {
for _, d := range data {
*d, err = ReadData(r)
if err != nil {
return
}
}
return
}
func ReadPacket(r io.Reader, typ *uint8, data ...*string) (err error) {
var connType uint8
err = binary.Read(r, binary.LittleEndian, &connType)
if err != nil {
return
}
*typ = connType
for _, d := range data {
*d, err = ReadData(r)
if err != nil {
return
}
}
return
}
func BuildPacket(typ uint8, data ...string) []byte {
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, typ)
for _, d := range data {
bytes := []byte(d)
binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
binary.Write(pkg, binary.LittleEndian, bytes)
}
return pkg.Bytes()
}
func BuildPacketData(data ...string) []byte {
pkg := new(bytes.Buffer)
for _, d := range data {
bytes := []byte(d)
binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
binary.Write(pkg, binary.LittleEndian, bytes)
}
return pkg.Bytes()
}
func SubStr(str string, start, end int) string { func SubStr(str string, start, end int) string {
if len(str) == 0 { if len(str) == 0 {
return "" return ""

View File

@ -617,3 +617,64 @@ func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) {
} }
return return
} }
type ConnManager struct {
pool ConcurrentMap
l *sync.Mutex
}
func NewConnManager() ConnManager {
cm := ConnManager{
pool: NewConcurrentMap(),
l: &sync.Mutex{},
}
return cm
}
func (cm *ConnManager) Add(key, ID string, conn *net.Conn) {
cm.pool.Upsert(key, nil, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} {
var conns ConcurrentMap
if !exist {
conns = NewConcurrentMap()
} else {
conns = valueInMap.(ConcurrentMap)
}
if conns.Has(ID) {
v, _ := conns.Get(ID)
(*v.(*net.Conn)).Close()
}
conns.Set(ID, conn)
log.Printf("%s conn added", key)
return conns
})
}
func (cm *ConnManager) Remove(key string) {
var conns ConcurrentMap
if v, ok := cm.pool.Get(key); ok {
conns = v.(ConcurrentMap)
conns.IterCb(func(key string, v interface{}) {
CloseConn(v.(*net.Conn))
})
log.Printf("%s conns closed", key)
}
cm.pool.Remove(key)
}
func (cm *ConnManager) RemoveOne(key string, ID string) {
defer cm.l.Unlock()
cm.l.Lock()
var conns ConcurrentMap
if v, ok := cm.pool.Get(key); ok {
conns = v.(ConcurrentMap)
if conns.Has(ID) {
v, _ := conns.Get(ID)
(*v.(*net.Conn)).Close()
conns.Remove(ID)
cm.pool.Set(key, conns)
log.Printf("%s %s conn closed", key, ID)
}
}
}
func (cm *ConnManager) RemoveAll() {
for _, k := range cm.pool.Keys() {
cm.Remove(k)
}
}