@ -6,13 +6,14 @@ import "golang.org/x/crypto/ssh"
|
||||
// t := tcp.Flag("tcp-timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int()
|
||||
|
||||
const (
|
||||
TYPE_TCP = "tcp"
|
||||
TYPE_UDP = "udp"
|
||||
TYPE_HTTP = "http"
|
||||
TYPE_TLS = "tls"
|
||||
CONN_CONTROL = uint8(1)
|
||||
CONN_SERVER = uint8(2)
|
||||
CONN_CLIENT = uint8(3)
|
||||
TYPE_TCP = "tcp"
|
||||
TYPE_UDP = "udp"
|
||||
TYPE_HTTP = "http"
|
||||
TYPE_TLS = "tls"
|
||||
CONN_CLIENT_CONTROL = uint8(1)
|
||||
CONN_SERVER_CONTROL = uint8(2)
|
||||
CONN_SERVER = uint8(3)
|
||||
CONN_CLIENT = uint8(4)
|
||||
)
|
||||
|
||||
type TunnelServerArgs struct {
|
||||
@ -27,6 +28,7 @@ type TunnelServerArgs struct {
|
||||
Remote *string
|
||||
Timeout *int
|
||||
Route *[]string
|
||||
Mgr *TunnelServerManager
|
||||
}
|
||||
type TunnelClientArgs struct {
|
||||
Parent *string
|
||||
|
||||
@ -2,7 +2,6 @@ package services
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"net"
|
||||
"proxy/utils"
|
||||
@ -11,13 +10,15 @@ import (
|
||||
)
|
||||
|
||||
type ServerConn struct {
|
||||
ClientLocalAddr string //tcp:2.2.22:333@ID
|
||||
Conn *net.Conn
|
||||
//ClientLocalAddr string //tcp:2.2.22:333@ID
|
||||
Conn *net.Conn
|
||||
}
|
||||
type TunnelBridge struct {
|
||||
cfg TunnelBridgeArgs
|
||||
serverConns utils.ConcurrentMap
|
||||
clientControlConns utils.ConcurrentMap
|
||||
cmServer utils.ConnManager
|
||||
cmClient utils.ConnManager
|
||||
}
|
||||
|
||||
func NewTunnelBridge() Service {
|
||||
@ -25,6 +26,8 @@ func NewTunnelBridge() Service {
|
||||
cfg: TunnelBridgeArgs{},
|
||||
serverConns: utils.NewConcurrentMap(),
|
||||
clientControlConns: utils.NewConcurrentMap(),
|
||||
cmServer: utils.NewConnManager(),
|
||||
cmClient: utils.NewConnManager(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,73 +55,27 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
|
||||
//log.Printf("connection from %s ", inConn.RemoteAddr())
|
||||
|
||||
reader := bufio.NewReader(inConn)
|
||||
var err error
|
||||
var connType uint8
|
||||
err = binary.Read(reader, binary.LittleEndian, &connType)
|
||||
err = utils.ReadPacket(reader, &connType)
|
||||
if err != nil {
|
||||
utils.CloseConn(&inConn)
|
||||
log.Printf("read error,ERR:%s", err)
|
||||
return
|
||||
}
|
||||
//log.Printf("conn type %d", connType)
|
||||
|
||||
var key, clientLocalAddr, ID string
|
||||
var connTypeStrMap = map[uint8]string{CONN_SERVER: "server", CONN_CLIENT: "client", CONN_CONTROL: "client"}
|
||||
var keyLength uint16
|
||||
err = binary.Read(reader, binary.LittleEndian, &keyLength)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_key := make([]byte, keyLength)
|
||||
n, err := reader.Read(_key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != int(keyLength) {
|
||||
return
|
||||
}
|
||||
key = string(_key)
|
||||
|
||||
if connType != CONN_CONTROL {
|
||||
var IDLength uint16
|
||||
err = binary.Read(reader, binary.LittleEndian, &IDLength)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_id := make([]byte, IDLength)
|
||||
n, err := reader.Read(_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != int(IDLength) {
|
||||
return
|
||||
}
|
||||
ID = string(_id)
|
||||
|
||||
if connType == CONN_SERVER {
|
||||
var addrLength uint16
|
||||
err = binary.Read(reader, binary.LittleEndian, &addrLength)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_addr := make([]byte, addrLength)
|
||||
n, err = reader.Read(_addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != int(addrLength) {
|
||||
return
|
||||
}
|
||||
clientLocalAddr = string(_addr)
|
||||
}
|
||||
}
|
||||
log.Printf("connection from %s , key: %s , id: %s", connTypeStrMap[connType], key, ID)
|
||||
|
||||
switch connType {
|
||||
case CONN_SERVER:
|
||||
addr := clientLocalAddr + "@" + ID
|
||||
var key, ID, clientLocalAddr, serverID string
|
||||
err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID)
|
||||
if err != nil {
|
||||
log.Printf("read error,ERR:%s", err)
|
||||
return
|
||||
}
|
||||
packet := utils.BuildPacketData(ID, clientLocalAddr, serverID)
|
||||
log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID)
|
||||
|
||||
//addr := clientLocalAddr + "@" + ID
|
||||
s.serverConns.Set(ID, ServerConn{
|
||||
Conn: &inConn,
|
||||
ClientLocalAddr: addr,
|
||||
Conn: &inConn,
|
||||
})
|
||||
for {
|
||||
item, ok := s.clientControlConns.Get(key)
|
||||
@ -128,17 +85,26 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
|
||||
continue
|
||||
}
|
||||
(*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3))
|
||||
_, err := (*item.(*net.Conn)).Write([]byte(addr))
|
||||
_, err := (*item.(*net.Conn)).Write(packet)
|
||||
(*item.(*net.Conn)).SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
} else {
|
||||
s.cmServer.Add(serverID, ID, &inConn)
|
||||
break
|
||||
}
|
||||
}
|
||||
case CONN_CLIENT:
|
||||
var key, ID, serverID string
|
||||
err = utils.ReadPacketData(reader, &key, &ID, &serverID)
|
||||
if err != nil {
|
||||
log.Printf("read error,ERR:%s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID)
|
||||
|
||||
serverConnItem, ok := s.serverConns.Get(ID)
|
||||
if !ok {
|
||||
inConn.Close()
|
||||
@ -147,15 +113,24 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
|
||||
}
|
||||
serverConn := serverConnItem.(ServerConn).Conn
|
||||
utils.IoBind(*serverConn, inConn, func(err error) {
|
||||
|
||||
(*serverConn).Close()
|
||||
utils.CloseConn(&inConn)
|
||||
s.serverConns.Remove(ID)
|
||||
s.cmClient.RemoveOne(key, ID)
|
||||
s.cmServer.RemoveOne(serverID, ID)
|
||||
log.Printf("conn %s released", ID)
|
||||
}, func(i int, b bool) {}, 0)
|
||||
s.cmClient.Add(key, ID, &inConn)
|
||||
log.Printf("conn %s created", ID)
|
||||
|
||||
case CONN_CONTROL:
|
||||
case CONN_CLIENT_CONTROL:
|
||||
var key string
|
||||
err = utils.ReadPacketData(reader, &key)
|
||||
if err != nil {
|
||||
log.Printf("read error,ERR:%s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("client control connection, key: %s", key)
|
||||
if s.clientControlConns.Has(key) {
|
||||
item, _ := s.clientControlConns.Get(key)
|
||||
(*item.(*net.Conn)).Close()
|
||||
@ -168,14 +143,59 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
|
||||
_, err = inConn.Read(b)
|
||||
if err != nil {
|
||||
inConn.Close()
|
||||
s.serverConns.Remove(ID)
|
||||
log.Printf("%s control conn from client released", key)
|
||||
s.cmClient.Remove(key)
|
||||
break
|
||||
} else {
|
||||
//log.Printf("%s heartbeat from client", key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
case CONN_SERVER_CONTROL:
|
||||
var serverID string
|
||||
err = utils.ReadPacketData(reader, &serverID)
|
||||
if err != nil {
|
||||
log.Printf("read error,ERR:%s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("server control connection, id: %s", serverID)
|
||||
writeDie := make(chan bool)
|
||||
readDie := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
inConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
|
||||
_, err = inConn.Write([]byte{0x00})
|
||||
inConn.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
log.Printf("control connection write err %s", err)
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
}
|
||||
close(writeDie)
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
signal := make([]byte, 1)
|
||||
inConn.SetReadDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err := inConn.Read(signal)
|
||||
inConn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
log.Printf("control connection read err: %s", err)
|
||||
break
|
||||
} else {
|
||||
// log.Printf("heartbeat from server ,id:%s", ID)
|
||||
}
|
||||
}
|
||||
close(readDie)
|
||||
}()
|
||||
select {
|
||||
case <-readDie:
|
||||
case <-writeDie:
|
||||
}
|
||||
utils.CloseConn(&inConn)
|
||||
s.cmServer.Remove(serverID)
|
||||
log.Printf("server control conn %s released", serverID)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@ -1,25 +1,24 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"proxy/utils"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TunnelClient struct {
|
||||
cfg TunnelClientArgs
|
||||
cm utils.ConnManager
|
||||
}
|
||||
|
||||
func NewTunnelClient() Service {
|
||||
return &TunnelClient{
|
||||
cfg: TunnelClientArgs{},
|
||||
cm: utils.NewConnManager(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,14 +36,20 @@ func (s *TunnelClient) CheckArgs() {
|
||||
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
||||
}
|
||||
func (s *TunnelClient) StopService() {
|
||||
s.cm.RemoveAll()
|
||||
}
|
||||
func (s *TunnelClient) Start(args interface{}) (err error) {
|
||||
s.cfg = args.(TunnelClientArgs)
|
||||
s.CheckArgs()
|
||||
s.InitService()
|
||||
log.Printf("proxy on tunnel client mode")
|
||||
var ctrlConn net.Conn
|
||||
for {
|
||||
ctrlConn, err := s.GetInConn(CONN_CONTROL, "")
|
||||
//close all conn
|
||||
s.cm.Remove(*s.cfg.Key)
|
||||
utils.CloseConn(&ctrlConn)
|
||||
|
||||
ctrlConn, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key)
|
||||
if err != nil {
|
||||
log.Printf("control connection err: %s, retrying...", err)
|
||||
time.Sleep(time.Second * 3)
|
||||
@ -53,6 +58,9 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
if ctrlConn == nil {
|
||||
break
|
||||
}
|
||||
ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
|
||||
_, err = ctrlConn.Write([]byte{0x00})
|
||||
ctrlConn.SetWriteDeadline(time.Time{})
|
||||
@ -65,23 +73,20 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
|
||||
}
|
||||
}()
|
||||
for {
|
||||
signal := make([]byte, 50)
|
||||
n, err := ctrlConn.Read(signal)
|
||||
var ID, clientLocalAddr, serverID string
|
||||
err = utils.ReadPacketData(ctrlConn, &ID, &clientLocalAddr, &serverID)
|
||||
if err != nil {
|
||||
utils.CloseConn(&ctrlConn)
|
||||
log.Printf("read connection signal err: %s, retrying...", err)
|
||||
break
|
||||
}
|
||||
addr := string(signal[:n])
|
||||
log.Printf("signal revecived:%s", addr)
|
||||
protocol := addr[:3]
|
||||
atIndex := strings.Index(addr, "@")
|
||||
ID := addr[atIndex+1:]
|
||||
localAddr := addr[4:atIndex]
|
||||
log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr)
|
||||
protocol := clientLocalAddr[:3]
|
||||
localAddr := clientLocalAddr[4:]
|
||||
if protocol == "udp" {
|
||||
go s.ServeUDP(localAddr, ID)
|
||||
go s.ServeUDP(localAddr, ID, serverID)
|
||||
} else {
|
||||
go s.ServeConn(localAddr, ID)
|
||||
go s.ServeConn(localAddr, ID, serverID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -89,25 +94,13 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
|
||||
func (s *TunnelClient) Clean() {
|
||||
s.StopService()
|
||||
}
|
||||
func (s *TunnelClient) GetInConn(typ uint8, ID string) (outConn net.Conn, err error) {
|
||||
func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, err error) {
|
||||
outConn, err = s.GetConn()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("connection err: %s", err)
|
||||
return
|
||||
}
|
||||
keyBytes := []byte(*s.cfg.Key)
|
||||
keyLength := uint16(len(keyBytes))
|
||||
pkg := new(bytes.Buffer)
|
||||
binary.Write(pkg, binary.LittleEndian, typ)
|
||||
binary.Write(pkg, binary.LittleEndian, keyLength)
|
||||
binary.Write(pkg, binary.LittleEndian, keyBytes)
|
||||
if ID != "" {
|
||||
IDBytes := []byte(ID)
|
||||
IDLength := uint16(len(IDBytes))
|
||||
binary.Write(pkg, binary.LittleEndian, IDLength)
|
||||
binary.Write(pkg, binary.LittleEndian, IDBytes)
|
||||
}
|
||||
_, err = outConn.Write(pkg.Bytes())
|
||||
_, err = outConn.Write(utils.BuildPacket(typ, data...))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("write connection data err: %s ,retrying...", err)
|
||||
utils.CloseConn(&outConn)
|
||||
@ -123,12 +116,13 @@ func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
func (s *TunnelClient) ServeUDP(localAddr, ID string) {
|
||||
func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) {
|
||||
var inConn net.Conn
|
||||
var err error
|
||||
// for {
|
||||
for {
|
||||
inConn, err = s.GetInConn(CONN_CLIENT, ID)
|
||||
s.cm.RemoveOne(*s.cfg.Key, ID)
|
||||
inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
|
||||
if err != nil {
|
||||
utils.CloseConn(&inConn)
|
||||
log.Printf("connection err: %s, retrying...", err)
|
||||
@ -138,13 +132,10 @@ func (s *TunnelClient) ServeUDP(localAddr, ID string) {
|
||||
break
|
||||
}
|
||||
}
|
||||
s.cm.Add(*s.cfg.Key, ID, &inConn)
|
||||
log.Printf("conn %s created", ID)
|
||||
// hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) {
|
||||
// log.Printf("hw err %s", err)
|
||||
// hw.Close()
|
||||
// })
|
||||
|
||||
for {
|
||||
// srcAddr, body, err := utils.ReadUDPPacket(&hw)
|
||||
srcAddr, body, err := utils.ReadUDPPacket(inConn)
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
log.Printf("connection %s released", ID)
|
||||
@ -197,11 +188,11 @@ func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr str
|
||||
}
|
||||
//log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs)
|
||||
}
|
||||
func (s *TunnelClient) ServeConn(localAddr, ID string) {
|
||||
func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) {
|
||||
var inConn, outConn net.Conn
|
||||
var err error
|
||||
for {
|
||||
inConn, err = s.GetInConn(CONN_CLIENT, ID)
|
||||
inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
|
||||
if err != nil {
|
||||
utils.CloseConn(&inConn)
|
||||
log.Printf("connection err: %s, retrying...", err)
|
||||
@ -236,6 +227,8 @@ func (s *TunnelClient) ServeConn(localAddr, ID string) {
|
||||
log.Printf("conn %s released", ID)
|
||||
utils.CloseConn(&inConn)
|
||||
utils.CloseConn(&outConn)
|
||||
s.cm.RemoveOne(*s.cfg.Key, ID)
|
||||
}, func(i int, b bool) {}, 0)
|
||||
s.cm.Add(*s.cfg.Key, ID, &inConn)
|
||||
log.Printf("conn %s created", ID)
|
||||
}
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@ -22,24 +20,33 @@ type TunnelServer struct {
|
||||
}
|
||||
|
||||
type TunnelServerManager struct {
|
||||
cfg TunnelServerArgs
|
||||
udpChn chan UDPItem
|
||||
sc utils.ServerChannel
|
||||
cfg TunnelServerArgs
|
||||
udpChn chan UDPItem
|
||||
sc utils.ServerChannel
|
||||
serverID string
|
||||
cm utils.ConnManager
|
||||
}
|
||||
|
||||
func NewTunnelServerManager() Service {
|
||||
return &TunnelServerManager{
|
||||
cfg: TunnelServerArgs{},
|
||||
udpChn: make(chan UDPItem, 50000),
|
||||
cfg: TunnelServerArgs{},
|
||||
udpChn: make(chan UDPItem, 50000),
|
||||
serverID: utils.Uniqueid(),
|
||||
cm: utils.NewConnManager(),
|
||||
}
|
||||
}
|
||||
func (s *TunnelServerManager) Start(args interface{}) (err error) {
|
||||
s.cfg = args.(TunnelServerArgs)
|
||||
s.CheckArgs()
|
||||
if *s.cfg.Parent != "" {
|
||||
log.Printf("use tls parent %s", *s.cfg.Parent)
|
||||
} else {
|
||||
log.Fatalf("parent required")
|
||||
}
|
||||
|
||||
s.InitService()
|
||||
|
||||
log.Printf("server id: %s", s.serverID)
|
||||
//log.Printf("route:%v", *s.cfg.Route)
|
||||
for _, _info := range *s.cfg.Route {
|
||||
IsUDP := *s.cfg.IsUDP
|
||||
@ -71,6 +78,7 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) {
|
||||
Remote: &remote,
|
||||
Key: &KEY,
|
||||
Timeout: s.cfg.Timeout,
|
||||
Mgr: s,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@ -80,7 +88,95 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) {
|
||||
return
|
||||
}
|
||||
func (s *TunnelServerManager) Clean() {
|
||||
|
||||
s.StopService()
|
||||
}
|
||||
func (s *TunnelServerManager) StopService() {
|
||||
s.cm.RemoveAll()
|
||||
}
|
||||
func (s *TunnelServerManager) CheckArgs() {
|
||||
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
|
||||
log.Fatalf("cert and key file required")
|
||||
}
|
||||
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
||||
}
|
||||
func (s *TunnelServerManager) InitService() {
|
||||
s.InitControlDeamon()
|
||||
}
|
||||
func (s *TunnelServerManager) InitControlDeamon() {
|
||||
go func() {
|
||||
var ctrlConn net.Conn
|
||||
var ID string
|
||||
for {
|
||||
//close all connection
|
||||
s.cm.Remove(ID)
|
||||
utils.CloseConn(&ctrlConn)
|
||||
ctrlConn, ID, err := s.GetOutConn(CONN_SERVER_CONTROL)
|
||||
if err != nil {
|
||||
log.Printf("control connection err: %s, retrying...", err)
|
||||
time.Sleep(time.Second * 3)
|
||||
utils.CloseConn(&ctrlConn)
|
||||
continue
|
||||
}
|
||||
log.Printf("control connection created,id:%s", ID)
|
||||
writeDie := make(chan bool)
|
||||
readDie := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
|
||||
_, err = ctrlConn.Write([]byte{0x00})
|
||||
ctrlConn.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
log.Printf("control connection write err %s", err)
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
}
|
||||
close(writeDie)
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
signal := make([]byte, 1)
|
||||
ctrlConn.SetReadDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err := ctrlConn.Read(signal)
|
||||
ctrlConn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
log.Printf("control connection read err: %s", err)
|
||||
break
|
||||
} else {
|
||||
// log.Printf("heartbeat from bridge")
|
||||
}
|
||||
}
|
||||
close(readDie)
|
||||
}()
|
||||
select {
|
||||
case <-readDie:
|
||||
case <-writeDie:
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
func (s *TunnelServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) {
|
||||
outConn, err = s.GetConn()
|
||||
if err != nil {
|
||||
log.Printf("connection err: %s", err)
|
||||
return
|
||||
}
|
||||
ID = s.serverID
|
||||
_, err = outConn.Write(utils.BuildPacket(typ, s.serverID))
|
||||
if err != nil {
|
||||
log.Printf("write connection data err: %s ,retrying...", err)
|
||||
utils.CloseConn(&outConn)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
func (s *TunnelServerManager) GetConn() (conn net.Conn, err error) {
|
||||
var _conn tls.Conn
|
||||
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
|
||||
if err == nil {
|
||||
conn = net.Conn(&_conn)
|
||||
}
|
||||
return
|
||||
}
|
||||
func NewTunnelServer() Service {
|
||||
return &TunnelServer{
|
||||
@ -102,13 +198,8 @@ func (s *TunnelServer) CheckArgs() {
|
||||
if *s.cfg.Remote == "" {
|
||||
log.Fatalf("remote required")
|
||||
}
|
||||
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
|
||||
log.Fatalf("cert and key file required")
|
||||
}
|
||||
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
|
||||
}
|
||||
func (s *TunnelServer) StopService() {
|
||||
}
|
||||
|
||||
func (s *TunnelServer) Start(args interface{}) (err error) {
|
||||
s.cfg = args.(TunnelServerArgs)
|
||||
s.CheckArgs()
|
||||
@ -138,7 +229,7 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
|
||||
var outConn net.Conn
|
||||
var ID string
|
||||
for {
|
||||
outConn, ID, err = s.GetOutConn("")
|
||||
outConn, ID, err = s.GetOutConn(CONN_SERVER)
|
||||
if err != nil {
|
||||
utils.CloseConn(&outConn)
|
||||
log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err)
|
||||
@ -148,17 +239,14 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
|
||||
break
|
||||
}
|
||||
}
|
||||
// hb := utils.NewHeartbeatReadWriter(&outConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
|
||||
// log.Printf("%s conn %s to bridge released", *s.cfg.Key, ID)
|
||||
// hb.Close()
|
||||
// })
|
||||
// utils.IoBind(inConn, &hb, func(err error) {
|
||||
utils.IoBind(inConn, outConn, func(err error) {
|
||||
utils.CloseConn(&outConn)
|
||||
utils.CloseConn(&inConn)
|
||||
s.cfg.Mgr.cm.RemoveOne(s.cfg.Mgr.serverID, ID)
|
||||
log.Printf("%s conn %s released", *s.cfg.Key, ID)
|
||||
}, func(i int, b bool) {}, 0)
|
||||
|
||||
//add conn
|
||||
s.cfg.Mgr.cm.Add(s.cfg.Mgr.serverID, ID, &inConn)
|
||||
log.Printf("%s conn %s created", *s.cfg.Key, ID)
|
||||
})
|
||||
if err != nil {
|
||||
@ -169,37 +257,20 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
|
||||
return
|
||||
}
|
||||
func (s *TunnelServer) Clean() {
|
||||
s.StopService()
|
||||
|
||||
}
|
||||
func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, ID string, err error) {
|
||||
func (s *TunnelServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) {
|
||||
outConn, err = s.GetConn()
|
||||
if err != nil {
|
||||
log.Printf("connection err: %s", err)
|
||||
return
|
||||
}
|
||||
keyBytes := []byte(*s.cfg.Key)
|
||||
keyLength := uint16(len(keyBytes))
|
||||
ID = utils.Uniqueid()
|
||||
IDBytes := []byte(ID)
|
||||
if id != "" {
|
||||
ID = id
|
||||
IDBytes = []byte(id)
|
||||
}
|
||||
IDLength := uint16(len(IDBytes))
|
||||
remoteAddr := []byte("tcp:" + *s.cfg.Remote)
|
||||
remoteAddr := "tcp:" + *s.cfg.Remote
|
||||
if *s.cfg.IsUDP {
|
||||
remoteAddr = []byte("udp:" + *s.cfg.Remote)
|
||||
remoteAddr = "udp:" + *s.cfg.Remote
|
||||
}
|
||||
remoteAddrLength := uint16(len(remoteAddr))
|
||||
pkg := new(bytes.Buffer)
|
||||
binary.Write(pkg, binary.LittleEndian, CONN_SERVER)
|
||||
binary.Write(pkg, binary.LittleEndian, keyLength)
|
||||
binary.Write(pkg, binary.LittleEndian, keyBytes)
|
||||
binary.Write(pkg, binary.LittleEndian, IDLength)
|
||||
binary.Write(pkg, binary.LittleEndian, IDBytes)
|
||||
binary.Write(pkg, binary.LittleEndian, remoteAddrLength)
|
||||
binary.Write(pkg, binary.LittleEndian, remoteAddr)
|
||||
_, err = outConn.Write(pkg.Bytes())
|
||||
ID = utils.Uniqueid()
|
||||
_, err = outConn.Write(utils.BuildPacket(typ, *s.cfg.Key, ID, remoteAddr, s.cfg.Mgr.serverID))
|
||||
if err != nil {
|
||||
log.Printf("write connection data err: %s ,retrying...", err)
|
||||
utils.CloseConn(&outConn)
|
||||
@ -232,7 +303,7 @@ func (s *TunnelServer) UDPConnDeamon() {
|
||||
RETRY:
|
||||
if outConn == nil {
|
||||
for {
|
||||
outConn, ID, err = s.GetOutConn("")
|
||||
outConn, ID, err = s.GetOutConn(CONN_SERVER)
|
||||
if err != nil {
|
||||
// cmdChn <- true
|
||||
outConn = nil
|
||||
|
||||
@ -315,6 +315,68 @@ func Uniqueid() string {
|
||||
s := fmt.Sprintf("%d", src.Int63())
|
||||
return s[len(s)-5:len(s)-1] + fmt.Sprintf("%d", uint64(time.Now().UnixNano()))[8:]
|
||||
}
|
||||
func ReadData(r io.Reader) (data string, err error) {
|
||||
var len uint16
|
||||
err = binary.Read(r, binary.LittleEndian, &len)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var n int
|
||||
_data := make([]byte, len)
|
||||
n, err = r.Read(_data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != int(len) {
|
||||
err = fmt.Errorf("error data len")
|
||||
return
|
||||
}
|
||||
data = string(_data)
|
||||
return
|
||||
}
|
||||
func ReadPacketData(r io.Reader, data ...*string) (err error) {
|
||||
for _, d := range data {
|
||||
*d, err = ReadData(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func ReadPacket(r io.Reader, typ *uint8, data ...*string) (err error) {
|
||||
var connType uint8
|
||||
err = binary.Read(r, binary.LittleEndian, &connType)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
*typ = connType
|
||||
for _, d := range data {
|
||||
*d, err = ReadData(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func BuildPacket(typ uint8, data ...string) []byte {
|
||||
pkg := new(bytes.Buffer)
|
||||
binary.Write(pkg, binary.LittleEndian, typ)
|
||||
for _, d := range data {
|
||||
bytes := []byte(d)
|
||||
binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
|
||||
binary.Write(pkg, binary.LittleEndian, bytes)
|
||||
}
|
||||
return pkg.Bytes()
|
||||
}
|
||||
func BuildPacketData(data ...string) []byte {
|
||||
pkg := new(bytes.Buffer)
|
||||
for _, d := range data {
|
||||
bytes := []byte(d)
|
||||
binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
|
||||
binary.Write(pkg, binary.LittleEndian, bytes)
|
||||
}
|
||||
return pkg.Bytes()
|
||||
}
|
||||
func SubStr(str string, start, end int) string {
|
||||
if len(str) == 0 {
|
||||
return ""
|
||||
|
||||
@ -617,3 +617,64 @@ func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type ConnManager struct {
|
||||
pool ConcurrentMap
|
||||
l *sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnManager() ConnManager {
|
||||
cm := ConnManager{
|
||||
pool: NewConcurrentMap(),
|
||||
l: &sync.Mutex{},
|
||||
}
|
||||
return cm
|
||||
}
|
||||
func (cm *ConnManager) Add(key, ID string, conn *net.Conn) {
|
||||
cm.pool.Upsert(key, nil, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} {
|
||||
var conns ConcurrentMap
|
||||
if !exist {
|
||||
conns = NewConcurrentMap()
|
||||
} else {
|
||||
conns = valueInMap.(ConcurrentMap)
|
||||
}
|
||||
if conns.Has(ID) {
|
||||
v, _ := conns.Get(ID)
|
||||
(*v.(*net.Conn)).Close()
|
||||
}
|
||||
conns.Set(ID, conn)
|
||||
log.Printf("%s conn added", key)
|
||||
return conns
|
||||
})
|
||||
}
|
||||
func (cm *ConnManager) Remove(key string) {
|
||||
var conns ConcurrentMap
|
||||
if v, ok := cm.pool.Get(key); ok {
|
||||
conns = v.(ConcurrentMap)
|
||||
conns.IterCb(func(key string, v interface{}) {
|
||||
CloseConn(v.(*net.Conn))
|
||||
})
|
||||
log.Printf("%s conns closed", key)
|
||||
}
|
||||
cm.pool.Remove(key)
|
||||
}
|
||||
func (cm *ConnManager) RemoveOne(key string, ID string) {
|
||||
defer cm.l.Unlock()
|
||||
cm.l.Lock()
|
||||
var conns ConcurrentMap
|
||||
if v, ok := cm.pool.Get(key); ok {
|
||||
conns = v.(ConcurrentMap)
|
||||
if conns.Has(ID) {
|
||||
v, _ := conns.Get(ID)
|
||||
(*v.(*net.Conn)).Close()
|
||||
conns.Remove(ID)
|
||||
cm.pool.Set(key, conns)
|
||||
log.Printf("%s %s conn closed", key, ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
func (cm *ConnManager) RemoveAll() {
|
||||
for _, k := range cm.pool.Keys() {
|
||||
cm.Remove(k)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user