完善内网穿透心跳机制

Signed-off-by: arraykeys@gmail.com <arraykeys@gmail.com>
This commit is contained in:
arraykeys@gmail.com
2017-10-23 16:28:10 +08:00
parent 96cd7a2b63
commit 078acaa0e8
6 changed files with 365 additions and 156 deletions

View File

@ -10,9 +10,10 @@ const (
TYPE_UDP = "udp"
TYPE_HTTP = "http"
TYPE_TLS = "tls"
CONN_CONTROL = uint8(1)
CONN_SERVER = uint8(2)
CONN_CLIENT = uint8(3)
CONN_CLIENT_CONTROL = uint8(1)
CONN_SERVER_CONTROL = uint8(2)
CONN_SERVER = uint8(3)
CONN_CLIENT = uint8(4)
)
type TunnelServerArgs struct {
@ -27,6 +28,7 @@ type TunnelServerArgs struct {
Remote *string
Timeout *int
Route *[]string
Mgr *TunnelServerManager
}
type TunnelClientArgs struct {
Parent *string

View File

@ -2,7 +2,6 @@ package services
import (
"bufio"
"encoding/binary"
"log"
"net"
"proxy/utils"
@ -11,13 +10,15 @@ import (
)
type ServerConn struct {
ClientLocalAddr string //tcp:2.2.22:333@ID
//ClientLocalAddr string //tcp:2.2.22:333@ID
Conn *net.Conn
}
type TunnelBridge struct {
cfg TunnelBridgeArgs
serverConns utils.ConcurrentMap
clientControlConns utils.ConcurrentMap
cmServer utils.ConnManager
cmClient utils.ConnManager
}
func NewTunnelBridge() Service {
@ -25,6 +26,8 @@ func NewTunnelBridge() Service {
cfg: TunnelBridgeArgs{},
serverConns: utils.NewConcurrentMap(),
clientControlConns: utils.NewConcurrentMap(),
cmServer: utils.NewConnManager(),
cmClient: utils.NewConnManager(),
}
}
@ -52,73 +55,27 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
//log.Printf("connection from %s ", inConn.RemoteAddr())
reader := bufio.NewReader(inConn)
var err error
var connType uint8
err = binary.Read(reader, binary.LittleEndian, &connType)
err = utils.ReadPacket(reader, &connType)
if err != nil {
utils.CloseConn(&inConn)
log.Printf("read error,ERR:%s", err)
return
}
//log.Printf("conn type %d", connType)
var key, clientLocalAddr, ID string
var connTypeStrMap = map[uint8]string{CONN_SERVER: "server", CONN_CLIENT: "client", CONN_CONTROL: "client"}
var keyLength uint16
err = binary.Read(reader, binary.LittleEndian, &keyLength)
if err != nil {
return
}
_key := make([]byte, keyLength)
n, err := reader.Read(_key)
if err != nil {
return
}
if n != int(keyLength) {
return
}
key = string(_key)
if connType != CONN_CONTROL {
var IDLength uint16
err = binary.Read(reader, binary.LittleEndian, &IDLength)
if err != nil {
return
}
_id := make([]byte, IDLength)
n, err := reader.Read(_id)
if err != nil {
return
}
if n != int(IDLength) {
return
}
ID = string(_id)
if connType == CONN_SERVER {
var addrLength uint16
err = binary.Read(reader, binary.LittleEndian, &addrLength)
if err != nil {
return
}
_addr := make([]byte, addrLength)
n, err = reader.Read(_addr)
if err != nil {
return
}
if n != int(addrLength) {
return
}
clientLocalAddr = string(_addr)
}
}
log.Printf("connection from %s , key: %s , id: %s", connTypeStrMap[connType], key, ID)
switch connType {
case CONN_SERVER:
addr := clientLocalAddr + "@" + ID
var key, ID, clientLocalAddr, serverID string
err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
packet := utils.BuildPacketData(ID, clientLocalAddr, serverID)
log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID)
//addr := clientLocalAddr + "@" + ID
s.serverConns.Set(ID, ServerConn{
Conn: &inConn,
ClientLocalAddr: addr,
})
for {
item, ok := s.clientControlConns.Get(key)
@ -128,17 +85,26 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
continue
}
(*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err := (*item.(*net.Conn)).Write([]byte(addr))
_, err := (*item.(*net.Conn)).Write(packet)
(*item.(*net.Conn)).SetWriteDeadline(time.Time{})
if err != nil {
log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
time.Sleep(time.Second * 3)
continue
} else {
s.cmServer.Add(serverID, ID, &inConn)
break
}
}
case CONN_CLIENT:
var key, ID, serverID string
err = utils.ReadPacketData(reader, &key, &ID, &serverID)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID)
serverConnItem, ok := s.serverConns.Get(ID)
if !ok {
inConn.Close()
@ -147,15 +113,24 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
}
serverConn := serverConnItem.(ServerConn).Conn
utils.IoBind(*serverConn, inConn, func(err error) {
(*serverConn).Close()
utils.CloseConn(&inConn)
s.serverConns.Remove(ID)
s.cmClient.RemoveOne(key, ID)
s.cmServer.RemoveOne(serverID, ID)
log.Printf("conn %s released", ID)
}, func(i int, b bool) {}, 0)
s.cmClient.Add(key, ID, &inConn)
log.Printf("conn %s created", ID)
case CONN_CONTROL:
case CONN_CLIENT_CONTROL:
var key string
err = utils.ReadPacketData(reader, &key)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
log.Printf("client control connection, key: %s", key)
if s.clientControlConns.Has(key) {
item, _ := s.clientControlConns.Get(key)
(*item.(*net.Conn)).Close()
@ -168,14 +143,59 @@ func (s *TunnelBridge) Start(args interface{}) (err error) {
_, err = inConn.Read(b)
if err != nil {
inConn.Close()
s.serverConns.Remove(ID)
log.Printf("%s control conn from client released", key)
s.cmClient.Remove(key)
break
} else {
//log.Printf("%s heartbeat from client", key)
}
}
}()
case CONN_SERVER_CONTROL:
var serverID string
err = utils.ReadPacketData(reader, &serverID)
if err != nil {
log.Printf("read error,ERR:%s", err)
return
}
log.Printf("server control connection, id: %s", serverID)
writeDie := make(chan bool)
readDie := make(chan bool)
go func() {
for {
inConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err = inConn.Write([]byte{0x00})
inConn.SetWriteDeadline(time.Time{})
if err != nil {
log.Printf("control connection write err %s", err)
break
}
time.Sleep(time.Second * 3)
}
close(writeDie)
}()
go func() {
for {
signal := make([]byte, 1)
inConn.SetReadDeadline(time.Now().Add(time.Second * 10))
_, err := inConn.Read(signal)
inConn.SetReadDeadline(time.Time{})
if err != nil {
log.Printf("control connection read err: %s", err)
break
} else {
// log.Printf("heartbeat from server ,id:%s", ID)
}
}
close(readDie)
}()
select {
case <-readDie:
case <-writeDie:
}
utils.CloseConn(&inConn)
s.cmServer.Remove(serverID)
log.Printf("server control conn %s released", serverID)
}
})
if err != nil {

View File

@ -1,25 +1,24 @@
package services
import (
"bytes"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"proxy/utils"
"strings"
"time"
)
type TunnelClient struct {
cfg TunnelClientArgs
cm utils.ConnManager
}
func NewTunnelClient() Service {
return &TunnelClient{
cfg: TunnelClientArgs{},
cm: utils.NewConnManager(),
}
}
@ -37,14 +36,20 @@ func (s *TunnelClient) CheckArgs() {
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
}
func (s *TunnelClient) StopService() {
s.cm.RemoveAll()
}
func (s *TunnelClient) Start(args interface{}) (err error) {
s.cfg = args.(TunnelClientArgs)
s.CheckArgs()
s.InitService()
log.Printf("proxy on tunnel client mode")
var ctrlConn net.Conn
for {
ctrlConn, err := s.GetInConn(CONN_CONTROL, "")
//close all conn
s.cm.Remove(*s.cfg.Key)
utils.CloseConn(&ctrlConn)
ctrlConn, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key)
if err != nil {
log.Printf("control connection err: %s, retrying...", err)
time.Sleep(time.Second * 3)
@ -53,6 +58,9 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
}
go func() {
for {
if ctrlConn == nil {
break
}
ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err = ctrlConn.Write([]byte{0x00})
ctrlConn.SetWriteDeadline(time.Time{})
@ -65,23 +73,20 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
}
}()
for {
signal := make([]byte, 50)
n, err := ctrlConn.Read(signal)
var ID, clientLocalAddr, serverID string
err = utils.ReadPacketData(ctrlConn, &ID, &clientLocalAddr, &serverID)
if err != nil {
utils.CloseConn(&ctrlConn)
log.Printf("read connection signal err: %s, retrying...", err)
break
}
addr := string(signal[:n])
log.Printf("signal revecived:%s", addr)
protocol := addr[:3]
atIndex := strings.Index(addr, "@")
ID := addr[atIndex+1:]
localAddr := addr[4:atIndex]
log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr)
protocol := clientLocalAddr[:3]
localAddr := clientLocalAddr[4:]
if protocol == "udp" {
go s.ServeUDP(localAddr, ID)
go s.ServeUDP(localAddr, ID, serverID)
} else {
go s.ServeConn(localAddr, ID)
go s.ServeConn(localAddr, ID, serverID)
}
}
}
@ -89,25 +94,13 @@ func (s *TunnelClient) Start(args interface{}) (err error) {
func (s *TunnelClient) Clean() {
s.StopService()
}
func (s *TunnelClient) GetInConn(typ uint8, ID string) (outConn net.Conn, err error) {
func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, err error) {
outConn, err = s.GetConn()
if err != nil {
err = fmt.Errorf("connection err: %s", err)
return
}
keyBytes := []byte(*s.cfg.Key)
keyLength := uint16(len(keyBytes))
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, typ)
binary.Write(pkg, binary.LittleEndian, keyLength)
binary.Write(pkg, binary.LittleEndian, keyBytes)
if ID != "" {
IDBytes := []byte(ID)
IDLength := uint16(len(IDBytes))
binary.Write(pkg, binary.LittleEndian, IDLength)
binary.Write(pkg, binary.LittleEndian, IDBytes)
}
_, err = outConn.Write(pkg.Bytes())
_, err = outConn.Write(utils.BuildPacket(typ, data...))
if err != nil {
err = fmt.Errorf("write connection data err: %s ,retrying...", err)
utils.CloseConn(&outConn)
@ -123,12 +116,13 @@ func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
}
return
}
func (s *TunnelClient) ServeUDP(localAddr, ID string) {
func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) {
var inConn net.Conn
var err error
// for {
for {
inConn, err = s.GetInConn(CONN_CLIENT, ID)
s.cm.RemoveOne(*s.cfg.Key, ID)
inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
if err != nil {
utils.CloseConn(&inConn)
log.Printf("connection err: %s, retrying...", err)
@ -138,13 +132,10 @@ func (s *TunnelClient) ServeUDP(localAddr, ID string) {
break
}
}
s.cm.Add(*s.cfg.Key, ID, &inConn)
log.Printf("conn %s created", ID)
// hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) {
// log.Printf("hw err %s", err)
// hw.Close()
// })
for {
// srcAddr, body, err := utils.ReadUDPPacket(&hw)
srcAddr, body, err := utils.ReadUDPPacket(inConn)
if err == io.EOF || err == io.ErrUnexpectedEOF {
log.Printf("connection %s released", ID)
@ -197,11 +188,11 @@ func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr str
}
//log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs)
}
func (s *TunnelClient) ServeConn(localAddr, ID string) {
func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) {
var inConn, outConn net.Conn
var err error
for {
inConn, err = s.GetInConn(CONN_CLIENT, ID)
inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
if err != nil {
utils.CloseConn(&inConn)
log.Printf("connection err: %s, retrying...", err)
@ -236,6 +227,8 @@ func (s *TunnelClient) ServeConn(localAddr, ID string) {
log.Printf("conn %s released", ID)
utils.CloseConn(&inConn)
utils.CloseConn(&outConn)
s.cm.RemoveOne(*s.cfg.Key, ID)
}, func(i int, b bool) {}, 0)
s.cm.Add(*s.cfg.Key, ID, &inConn)
log.Printf("conn %s created", ID)
}

View File

@ -1,9 +1,7 @@
package services
import (
"bytes"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"log"
@ -25,21 +23,30 @@ type TunnelServerManager struct {
cfg TunnelServerArgs
udpChn chan UDPItem
sc utils.ServerChannel
serverID string
cm utils.ConnManager
}
func NewTunnelServerManager() Service {
return &TunnelServerManager{
cfg: TunnelServerArgs{},
udpChn: make(chan UDPItem, 50000),
serverID: utils.Uniqueid(),
cm: utils.NewConnManager(),
}
}
func (s *TunnelServerManager) Start(args interface{}) (err error) {
s.cfg = args.(TunnelServerArgs)
s.CheckArgs()
if *s.cfg.Parent != "" {
log.Printf("use tls parent %s", *s.cfg.Parent)
} else {
log.Fatalf("parent required")
}
s.InitService()
log.Printf("server id: %s", s.serverID)
//log.Printf("route:%v", *s.cfg.Route)
for _, _info := range *s.cfg.Route {
IsUDP := *s.cfg.IsUDP
@ -71,6 +78,7 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) {
Remote: &remote,
Key: &KEY,
Timeout: s.cfg.Timeout,
Mgr: s,
})
if err != nil {
@ -80,7 +88,95 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) {
return
}
func (s *TunnelServerManager) Clean() {
s.StopService()
}
func (s *TunnelServerManager) StopService() {
s.cm.RemoveAll()
}
func (s *TunnelServerManager) CheckArgs() {
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
log.Fatalf("cert and key file required")
}
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
}
func (s *TunnelServerManager) InitService() {
s.InitControlDeamon()
}
func (s *TunnelServerManager) InitControlDeamon() {
go func() {
var ctrlConn net.Conn
var ID string
for {
//close all connection
s.cm.Remove(ID)
utils.CloseConn(&ctrlConn)
ctrlConn, ID, err := s.GetOutConn(CONN_SERVER_CONTROL)
if err != nil {
log.Printf("control connection err: %s, retrying...", err)
time.Sleep(time.Second * 3)
utils.CloseConn(&ctrlConn)
continue
}
log.Printf("control connection created,id:%s", ID)
writeDie := make(chan bool)
readDie := make(chan bool)
go func() {
for {
ctrlConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err = ctrlConn.Write([]byte{0x00})
ctrlConn.SetWriteDeadline(time.Time{})
if err != nil {
log.Printf("control connection write err %s", err)
break
}
time.Sleep(time.Second * 3)
}
close(writeDie)
}()
go func() {
for {
signal := make([]byte, 1)
ctrlConn.SetReadDeadline(time.Now().Add(time.Second * 10))
_, err := ctrlConn.Read(signal)
ctrlConn.SetReadDeadline(time.Time{})
if err != nil {
log.Printf("control connection read err: %s", err)
break
} else {
// log.Printf("heartbeat from bridge")
}
}
close(readDie)
}()
select {
case <-readDie:
case <-writeDie:
}
}
}()
}
func (s *TunnelServerManager) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) {
outConn, err = s.GetConn()
if err != nil {
log.Printf("connection err: %s", err)
return
}
ID = s.serverID
_, err = outConn.Write(utils.BuildPacket(typ, s.serverID))
if err != nil {
log.Printf("write connection data err: %s ,retrying...", err)
utils.CloseConn(&outConn)
return
}
return
}
func (s *TunnelServerManager) GetConn() (conn net.Conn, err error) {
var _conn tls.Conn
_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes)
if err == nil {
conn = net.Conn(&_conn)
}
return
}
func NewTunnelServer() Service {
return &TunnelServer{
@ -102,13 +198,8 @@ func (s *TunnelServer) CheckArgs() {
if *s.cfg.Remote == "" {
log.Fatalf("remote required")
}
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
log.Fatalf("cert and key file required")
}
s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
}
func (s *TunnelServer) StopService() {
}
func (s *TunnelServer) Start(args interface{}) (err error) {
s.cfg = args.(TunnelServerArgs)
s.CheckArgs()
@ -138,7 +229,7 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
var outConn net.Conn
var ID string
for {
outConn, ID, err = s.GetOutConn("")
outConn, ID, err = s.GetOutConn(CONN_SERVER)
if err != nil {
utils.CloseConn(&outConn)
log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err)
@ -148,17 +239,14 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
break
}
}
// hb := utils.NewHeartbeatReadWriter(&outConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
// log.Printf("%s conn %s to bridge released", *s.cfg.Key, ID)
// hb.Close()
// })
// utils.IoBind(inConn, &hb, func(err error) {
utils.IoBind(inConn, outConn, func(err error) {
utils.CloseConn(&outConn)
utils.CloseConn(&inConn)
s.cfg.Mgr.cm.RemoveOne(s.cfg.Mgr.serverID, ID)
log.Printf("%s conn %s released", *s.cfg.Key, ID)
}, func(i int, b bool) {}, 0)
//add conn
s.cfg.Mgr.cm.Add(s.cfg.Mgr.serverID, ID, &inConn)
log.Printf("%s conn %s created", *s.cfg.Key, ID)
})
if err != nil {
@ -169,37 +257,20 @@ func (s *TunnelServer) Start(args interface{}) (err error) {
return
}
func (s *TunnelServer) Clean() {
s.StopService()
}
func (s *TunnelServer) GetOutConn(id string) (outConn net.Conn, ID string, err error) {
func (s *TunnelServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) {
outConn, err = s.GetConn()
if err != nil {
log.Printf("connection err: %s", err)
return
}
keyBytes := []byte(*s.cfg.Key)
keyLength := uint16(len(keyBytes))
ID = utils.Uniqueid()
IDBytes := []byte(ID)
if id != "" {
ID = id
IDBytes = []byte(id)
}
IDLength := uint16(len(IDBytes))
remoteAddr := []byte("tcp:" + *s.cfg.Remote)
remoteAddr := "tcp:" + *s.cfg.Remote
if *s.cfg.IsUDP {
remoteAddr = []byte("udp:" + *s.cfg.Remote)
remoteAddr = "udp:" + *s.cfg.Remote
}
remoteAddrLength := uint16(len(remoteAddr))
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, CONN_SERVER)
binary.Write(pkg, binary.LittleEndian, keyLength)
binary.Write(pkg, binary.LittleEndian, keyBytes)
binary.Write(pkg, binary.LittleEndian, IDLength)
binary.Write(pkg, binary.LittleEndian, IDBytes)
binary.Write(pkg, binary.LittleEndian, remoteAddrLength)
binary.Write(pkg, binary.LittleEndian, remoteAddr)
_, err = outConn.Write(pkg.Bytes())
ID = utils.Uniqueid()
_, err = outConn.Write(utils.BuildPacket(typ, *s.cfg.Key, ID, remoteAddr, s.cfg.Mgr.serverID))
if err != nil {
log.Printf("write connection data err: %s ,retrying...", err)
utils.CloseConn(&outConn)
@ -232,7 +303,7 @@ func (s *TunnelServer) UDPConnDeamon() {
RETRY:
if outConn == nil {
for {
outConn, ID, err = s.GetOutConn("")
outConn, ID, err = s.GetOutConn(CONN_SERVER)
if err != nil {
// cmdChn <- true
outConn = nil

View File

@ -315,6 +315,68 @@ func Uniqueid() string {
s := fmt.Sprintf("%d", src.Int63())
return s[len(s)-5:len(s)-1] + fmt.Sprintf("%d", uint64(time.Now().UnixNano()))[8:]
}
func ReadData(r io.Reader) (data string, err error) {
var len uint16
err = binary.Read(r, binary.LittleEndian, &len)
if err != nil {
return
}
var n int
_data := make([]byte, len)
n, err = r.Read(_data)
if err != nil {
return
}
if n != int(len) {
err = fmt.Errorf("error data len")
return
}
data = string(_data)
return
}
func ReadPacketData(r io.Reader, data ...*string) (err error) {
for _, d := range data {
*d, err = ReadData(r)
if err != nil {
return
}
}
return
}
func ReadPacket(r io.Reader, typ *uint8, data ...*string) (err error) {
var connType uint8
err = binary.Read(r, binary.LittleEndian, &connType)
if err != nil {
return
}
*typ = connType
for _, d := range data {
*d, err = ReadData(r)
if err != nil {
return
}
}
return
}
func BuildPacket(typ uint8, data ...string) []byte {
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, typ)
for _, d := range data {
bytes := []byte(d)
binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
binary.Write(pkg, binary.LittleEndian, bytes)
}
return pkg.Bytes()
}
func BuildPacketData(data ...string) []byte {
pkg := new(bytes.Buffer)
for _, d := range data {
bytes := []byte(d)
binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
binary.Write(pkg, binary.LittleEndian, bytes)
}
return pkg.Bytes()
}
func SubStr(str string, start, end int) string {
if len(str) == 0 {
return ""

View File

@ -617,3 +617,64 @@ func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) {
}
return
}
type ConnManager struct {
pool ConcurrentMap
l *sync.Mutex
}
func NewConnManager() ConnManager {
cm := ConnManager{
pool: NewConcurrentMap(),
l: &sync.Mutex{},
}
return cm
}
func (cm *ConnManager) Add(key, ID string, conn *net.Conn) {
cm.pool.Upsert(key, nil, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} {
var conns ConcurrentMap
if !exist {
conns = NewConcurrentMap()
} else {
conns = valueInMap.(ConcurrentMap)
}
if conns.Has(ID) {
v, _ := conns.Get(ID)
(*v.(*net.Conn)).Close()
}
conns.Set(ID, conn)
log.Printf("%s conn added", key)
return conns
})
}
func (cm *ConnManager) Remove(key string) {
var conns ConcurrentMap
if v, ok := cm.pool.Get(key); ok {
conns = v.(ConcurrentMap)
conns.IterCb(func(key string, v interface{}) {
CloseConn(v.(*net.Conn))
})
log.Printf("%s conns closed", key)
}
cm.pool.Remove(key)
}
func (cm *ConnManager) RemoveOne(key string, ID string) {
defer cm.l.Unlock()
cm.l.Lock()
var conns ConcurrentMap
if v, ok := cm.pool.Get(key); ok {
conns = v.(ConcurrentMap)
if conns.Has(ID) {
v, _ := conns.Get(ID)
(*v.(*net.Conn)).Close()
conns.Remove(ID)
cm.pool.Set(key, conns)
log.Printf("%s %s conn closed", key, ID)
}
}
}
func (cm *ConnManager) RemoveAll() {
for _, k := range cm.pool.Keys() {
cm.Remove(k)
}
}