goproxy/services/mux_bridge.go
2018-02-26 10:33:12 +08:00

162 lines
3.5 KiB
Go

package services
import (
"bufio"
"io"
"log"
"net"
"snail007/proxy/utils"
"strconv"
"time"
"github.com/xtaci/smux"
)
// MuxBridge accepts TLS connections and bridges them over smux sessions.
// A connection announced as CONN_SERVER carries an smux server session whose
// streams are forwarded (see callback) onto a client session. A connection
// announced as CONN_CLIENT registers an smux client session in
// clientControlConns so those streams can be opened on it.
type MuxBridge struct {
	// cfg holds the arguments passed to Start.
	cfg MuxBridgeArgs
	// clientControlConns maps a client key to that client's *smux.Session.
	clientControlConns utils.ConcurrentMap
	// router picks a client key when a server stream asks for key "*".
	router utils.ClientKeyRouter
}
// NewMuxBridge returns a MuxBridge as a Service. The bridge starts with empty
// arguments and an empty client-session map. Its client-key router is built
// over that same map; the 50000 argument is passed to
// utils.NewClientKeyRouter unchanged.
func NewMuxBridge() Service {
	bridge := new(MuxBridge)
	bridge.cfg = MuxBridgeArgs{}
	bridge.clientControlConns = utils.NewConcurrentMap()
	// The router receives a pointer to the bridge's own map field, so both
	// operate on the same map; the field must therefore be set first.
	bridge.router = utils.NewClientKeyRouter(&bridge.clientControlConns, 50000)
	return bridge
}
// InitService does nothing for MuxBridge; Start calls it after CheckArgs.
func (s *MuxBridge) InitService() {
}
// CheckArgs requires both a TLS certificate file and a key file. If either
// path is empty, it terminates the process via log.Fatalf. Otherwise it loads
// both files with utils.TlsBytes into cfg.CertBytes and cfg.KeyBytes.
func (s *MuxBridge) CheckArgs() {
	// Short-circuits exactly like the original: KeyFile is only dereferenced
	// when CertFile is non-empty.
	missing := *s.cfg.CertFile == "" || *s.cfg.KeyFile == ""
	if missing {
		log.Fatalf("cert and key file required")
	}
	certFile, keyFile := *s.cfg.CertFile, *s.cfg.KeyFile
	s.cfg.CertBytes, s.cfg.KeyBytes = utils.TlsBytes(certFile, keyFile)
}
// StopService does nothing for MuxBridge; Clean calls it.
// NOTE(review): the TLS listener and smux sessions created by Start are not
// closed here.
func (s *MuxBridge) StopService() {
}
// muxBridgeConn lets smux read through the bufio.Reader that parsed the
// handshake packet, so bytes the reader already buffered past the handshake
// are delivered to smux instead of being lost. Write, Close and the deadline
// methods go straight to the embedded connection.
type muxBridgeConn struct {
	net.Conn
	reader *bufio.Reader
}

// Read returns any bytes still buffered in reader before reading from the
// underlying connection.
func (c *muxBridgeConn) Read(p []byte) (int, error) {
	return c.reader.Read(p)
}

// Start stores args as the bridge configuration; args must be a
// MuxBridgeArgs, and any other type panics on the type assertion. It then
// validates the configuration and starts a TLS listener on cfg.Local. Every
// accepted connection is passed to handleConn.
//
// Start returns an error if cfg.Local is not a valid host:port with a numeric
// port, or if the listener cannot be started.
func (s *MuxBridge) Start(args interface{}) (err error) {
	s.cfg = args.(MuxBridgeArgs)
	s.CheckArgs()
	s.InitService()
	host, port, err := net.SplitHostPort(*s.cfg.Local)
	if err != nil {
		return
	}
	p, err := strconv.Atoi(port)
	if err != nil {
		return
	}
	sc := utils.NewServerChannel(host, p)
	err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.handleConn)
	if err != nil {
		return
	}
	log.Printf("proxy on mux bridge mode %s", (*sc.Listener).Addr())
	return
}

// handleConn reads the handshake packet (connection type and key) from an
// accepted connection and dispatches on the type. A connection whose handshake
// cannot be read, or whose type is unknown, is closed.
func (s *MuxBridge) handleConn(inConn net.Conn) {
	reader := bufio.NewReader(inConn)
	var connType uint8
	var key string
	if err := utils.ReadPacket(reader, &connType, &key); err != nil {
		log.Printf("read error,ERR:%s", err)
		utils.CloseConn(&inConn)
		return
	}
	switch connType {
	case CONN_SERVER:
		var serverID string
		if err := utils.ReadPacketData(reader, &serverID); err != nil {
			log.Printf("read error,ERR:%s", err)
			utils.CloseConn(&inConn)
			return
		}
		log.Printf("server connection %s %s connected", serverID, key)
		var conn net.Conn = &muxBridgeConn{Conn: inConn, reader: reader}
		s.serveServerConn(conn, serverID, key)
	case CONN_CLIENT:
		log.Printf("client connection %s connected", key)
		var conn net.Conn = &muxBridgeConn{Conn: inConn, reader: reader}
		s.registerClientConn(conn, key)
	default:
		log.Printf("unknown connection type %d from %s, key %s", connType, inConn.RemoteAddr(), key)
		utils.CloseConn(&inConn)
	}
}

// serveServerConn runs an smux server session over conn. Each stream the
// remote side opens is passed to callback in its own goroutine. Once
// AcceptStream fails, the session and conn are closed and it returns.
func (s *MuxBridge) serveServerConn(conn net.Conn, serverID, key string) {
	session, err := smux.Server(conn, nil)
	if err != nil {
		utils.CloseConn(&conn)
		log.Printf("server session error,ERR:%s", err)
		return
	}
	for {
		stream, err := session.AcceptStream()
		if err != nil {
			session.Close()
			utils.CloseConn(&conn)
			return
		}
		go s.callback(stream, serverID, key)
	}
}

// registerClientConn opens an smux client session over conn and stores it in
// clientControlConns under key. A background goroutine polls the session every
// 5 seconds. When it finds the session closed, it removes the entry, but only
// if the entry still refers to this session. That way a client that reconnected
// under the same key keeps its newer session.
func (s *MuxBridge) registerClientConn(conn net.Conn, key string) {
	session, err := smux.Client(conn, nil)
	if err != nil {
		utils.CloseConn(&conn)
		log.Printf("client session error,ERR:%s", err)
		return
	}
	// A previous session stored under the same key is replaced, not closed;
	// it stays open until its peer goes away.
	s.clientControlConns.Set(key, session)
	go func() {
		for !session.IsClosed() {
			time.Sleep(time.Second * 5)
		}
		// NOTE(review): Get-then-Remove is not atomic; a Set landing in between
		// could still be removed. Use an atomic compare-and-remove if
		// utils.ConcurrentMap offers one.
		if current, ok := s.clientControlConns.Get(key); ok && current == session {
			s.clientControlConns.Remove(key)
		}
	}()
}
// Clean shuts the bridge down by delegating to StopService, which is
// currently a no-op.
func (s *MuxBridge) Clean() {
	s.StopService()
}
// callback bridges one stream, accepted from a server session, to a client
// session.
//
// key selects the client's *smux.Session in clientControlConns. The key "*"
// asks the router for a client key, and it is re-resolved on every attempt so
// a retry can reach a different or newly connected client.
//
// Up to 20 attempts, 3 seconds apart, are made to find the session and open a
// stream on it. On success, data is copied in both directions until either
// direction ends, and then both ends are closed. If every attempt fails,
// inConn is closed so the server side is not left with an unserved stream.
func (s *MuxBridge) callback(inConn net.Conn, serverID, key string) {
	const (
		maxAttempts = 20
		retryDelay  = 3 * time.Second
	)
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		if attempt > 1 {
			time.Sleep(retryDelay)
		}
		clientKey := key
		if key == "*" {
			clientKey = s.router.GetKey()
		}
		session, ok := s.clientControlConns.Get(clientKey)
		if !ok {
			log.Printf("client %s session not exists for server stream %s", clientKey, serverID)
			continue
		}
		stream, err := session.(*smux.Session).OpenStream()
		if err != nil {
			log.Printf("%s client session open stream %s fail, err: %s, retrying...", clientKey, serverID, err)
			continue
		}
		log.Printf("%s server %s stream created", clientKey, serverID)
		// Buffered for both senders so the copier that finishes second never
		// blocks and leaks. Copy errors are ignored: either side ending (cleanly
		// or not) tears down the whole pair below.
		done := make(chan struct{}, 2)
		go func() {
			io.Copy(stream, inConn)
			done <- struct{}{}
		}()
		go func() {
			io.Copy(inConn, stream)
			done <- struct{}{}
		}()
		<-done
		stream.Close()
		inConn.Close()
		log.Printf("%s server %s stream released", clientKey, serverID)
		return
	}
	log.Printf("client %s unavailable for server stream %s after %d attempts, closing stream", key, serverID, maxAttempts)
	inConn.Close()
}