goproxy/services/tunnel_bridge.go
2017-10-07 17:31:31 +08:00

192 lines
4.8 KiB
Go

package services
import (
"bufio"
"encoding/binary"
"log"
"net"
"proxy/utils"
"strconv"
"time"
)
type ServerConn struct {
ClientLocalAddr string //tcp:2.2.22:333@ID
Conn *net.Conn
//Conn *utils.HeartbeatReadWriter
}
type TunnelBridge struct {
cfg TunnelBridgeArgs
serverConns utils.ConcurrentMap
clientControlConns utils.ConcurrentMap
}
func NewTunnelBridge() Service {
return &TunnelBridge{
cfg: TunnelBridgeArgs{},
serverConns: utils.NewConcurrentMap(),
clientControlConns: utils.NewConcurrentMap(),
}
}
func (s *TunnelBridge) InitService() {
}
func (s *TunnelBridge) Check() {
if s.cfg.CertBytes == nil || s.cfg.KeyBytes == nil {
log.Fatalf("cert and key file required")
}
}
func (s *TunnelBridge) StopService() {
}
func (s *TunnelBridge) Start(args interface{}) (err error) {
s.cfg = args.(TunnelBridgeArgs)
s.Check()
s.InitService()
host, port, _ := net.SplitHostPort(*s.cfg.Local)
p, _ := strconv.Atoi(port)
sc := utils.NewServerChannel(host, p)
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) {
//log.Printf("connection from %s ", inConn.RemoteAddr())
reader := bufio.NewReader(inConn)
var connType uint8
err = binary.Read(reader, binary.LittleEndian, &connType)
if err != nil {
utils.CloseConn(&inConn)
return
}
//log.Printf("conn type %d", connType)
var key, clientLocalAddr, ID string
var connTypeStrMap = map[uint8]string{CONN_SERVER: "server", CONN_CLIENT: "client", CONN_CONTROL: "client"}
var keyLength uint16
err = binary.Read(reader, binary.LittleEndian, &keyLength)
if err != nil {
return
}
_key := make([]byte, keyLength)
n, err := reader.Read(_key)
if err != nil {
return
}
if n != int(keyLength) {
return
}
key = string(_key)
//log.Printf("conn key %s", key)
if connType != CONN_CONTROL {
var IDLength uint16
err = binary.Read(reader, binary.LittleEndian, &IDLength)
if err != nil {
return
}
_id := make([]byte, IDLength)
n, err := reader.Read(_id)
if err != nil {
return
}
if n != int(IDLength) {
return
}
ID = string(_id)
if connType == CONN_SERVER {
var addrLength uint16
err = binary.Read(reader, binary.LittleEndian, &addrLength)
if err != nil {
return
}
_addr := make([]byte, addrLength)
n, err = reader.Read(_addr)
if err != nil {
return
}
if n != int(addrLength) {
return
}
clientLocalAddr = string(_addr)
}
}
log.Printf("connection from %s , key: %s , id: %s", connTypeStrMap[connType], key, ID)
switch connType {
case CONN_SERVER:
// hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
// log.Printf("%s conn %s from server released", key, ID)
// s.serverConns.Remove(ID)
// })
addr := clientLocalAddr + "@" + ID
s.serverConns.Set(ID, ServerConn{
//Conn: &hb,
Conn: &inConn,
ClientLocalAddr: addr,
})
for {
item, ok := s.clientControlConns.Get(key)
if !ok {
log.Printf("client %s control conn not exists", key)
time.Sleep(time.Second * 3)
continue
}
_, err := (*item.(*net.Conn)).Write([]byte(addr))
if err != nil {
log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
time.Sleep(time.Second * 3)
continue
} else {
break
}
}
case CONN_CLIENT:
serverConnItem, ok := s.serverConns.Get(ID)
if !ok {
inConn.Close()
log.Printf("server conn %s exists", ID)
return
}
serverConn := serverConnItem.(ServerConn).Conn
// hw := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hw *utils.HeartbeatReadWriter) {
// log.Printf("%s conn %s from client released", key, ID)
// hw.Close()
// })
utils.IoBind(*serverConn, inConn, func(err error) {
// utils.IoBind(serverConn, inConn, func(isSrcErr bool, err error) {
//serverConn.Close()
(*serverConn).Close()
utils.CloseConn(&inConn)
// hw.Close()
s.serverConns.Remove(ID)
log.Printf("conn %s released", ID)
}, func(i int, b bool) {}, 0)
log.Printf("conn %s created", ID)
case CONN_CONTROL:
if s.clientControlConns.Has(key) {
item, _ := s.clientControlConns.Get(key)
//(*item.(*utils.HeartbeatReadWriter)).Close()
(*item.(*net.Conn)).Close()
}
// hb := utils.NewHeartbeatReadWriter(&inConn, 3, func(err error, hb *utils.HeartbeatReadWriter) {
// log.Printf("client %s disconnected", key)
// s.clientControlConns.Remove(key)
// })
// s.clientControlConns.Set(key, &hb)
s.clientControlConns.Set(key, &inConn)
log.Printf("set client %s control conn", key)
}
})
if err != nil {
return
}
log.Printf("proxy on tunnel bridge mode %s", (*sc.Listener).Addr())
return
}
func (s *TunnelBridge) Clean() {
s.StopService()
}