diff --git a/CHANGELOG b/CHANGELOG index ba7a176..d4c7c1e 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -2,7 +2,11 @@ proxy更新日志 v4.7 1.优化了bridge的日志,增加了client和server的掉线日志. 2.优化了sps读取http(s)代理响应的缓冲大小,同时优化了CONNECT请求, -避免了某些代理服务器返回过多数据导致不能正常通讯的问题. + 避免了某些代理服务器返回过多数据导致不能正常通讯的问题. +3.去除了鸡肋连接池功能. +4.增加了gomobile sdk,对安卓/IOS提供支持. +5.优化了所有服务代码,方便对sdk提供支持. + v4.6 diff --git a/sdk/sdk.go b/sdk/sdk.go index cb37165..ee881f4 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -299,17 +299,30 @@ func Start(argsStr string) (errStr string) { } //regist services and run service - services.Regist("http", services.NewHTTP(), httpArgs) - services.Regist("tcp", services.NewTCP(), tcpArgs) - services.Regist("udp", services.NewUDP(), udpArgs) - services.Regist("tserver", services.NewTunnelServerManager(), tunnelServerArgs) - services.Regist("tclient", services.NewTunnelClient(), tunnelClientArgs) - services.Regist("tbridge", services.NewTunnelBridge(), tunnelBridgeArgs) - services.Regist("server", services.NewMuxServerManager(), muxServerArgs) - services.Regist("client", services.NewMuxClient(), muxClientArgs) - services.Regist("bridge", services.NewMuxBridge(), muxBridgeArgs) - services.Regist("socks", services.NewSocks(), socksArgs) - services.Regist("sps", services.NewSPS(), spsArgs) + switch serviceName { + case "http": + services.Regist("http", services.NewHTTP(), httpArgs) + case "tcp": + services.Regist("tcp", services.NewTCP(), tcpArgs) + case "udp": + services.Regist("udp", services.NewUDP(), udpArgs) + case "tserver": + services.Regist("tserver", services.NewTunnelServerManager(), tunnelServerArgs) + case "tclient": + services.Regist("tclient", services.NewTunnelClient(), tunnelClientArgs) + case "tbridge": + services.Regist("tbridge", services.NewTunnelBridge(), tunnelBridgeArgs) + case "server": + services.Regist("server", services.NewMuxServerManager(), muxServerArgs) + case "client": + services.Regist("client", services.NewMuxClient(), muxClientArgs) + case "bridge": + services.Regist("bridge", services.NewMuxBridge(), muxBridgeArgs) + case "socks": + services.Regist("socks", services.NewSocks(), socksArgs) + case "sps": + services.Regist("sps", services.NewSPS(), spsArgs) + } service, err = services.Run(serviceName) if err != nil { diff --git a/services/http.go b/services/http.go index 62fb269..c4da7fd 100644 --- a/services/http.go +++ b/services/http.go @@ -16,22 +16,26 @@ import ( ) type HTTP struct { - outPool utils.OutPool + outPool utils.OutConn cfg HTTPArgs checker utils.Checker basicAuth utils.BasicAuth sshClient *ssh.Client lockChn chan bool domainResolver utils.DomainResolver + isStop bool + serverChannels []*utils.ServerChannel } func NewHTTP() Service { return &HTTP{ - outPool: utils.OutPool{}, - cfg: HTTPArgs{}, - checker: utils.Checker{}, - basicAuth: utils.BasicAuth{}, - lockChn: make(chan bool, 1), + outPool: utils.OutConn{}, + cfg: HTTPArgs{}, + checker: utils.Checker{}, + basicAuth: utils.BasicAuth{}, + lockChn: make(chan bool, 1), + isStop: false, + serverChannels: []*utils.ServerChannel{}, } } func (s *HTTP) CheckArgs() (err error) { @@ -102,6 +106,9 @@ func (s *HTTP) InitService() (err error) { go func() { //循环检查ssh网络连通性 for { + if s.isStop { + return + } conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2) if err == nil { conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) @@ -127,8 +134,26 @@ func (s *HTTP) InitService() (err error) { return } func (s *HTTP) StopService() { - if s.outPool.Pool != nil { - s.outPool.Pool.ReleaseAll() + defer func() { + e := recover() + if e != nil { + log.Printf("stop http(s) service crashed,%s", e) + } else { + log.Printf("service http(s) stoped,%s", e) + } + }() + s.isStop = true + s.checker.Stop() + if s.sshClient != nil { + s.sshClient.Close() + } + for _, sc := range s.serverChannels { + if sc.Listener != nil && *sc.Listener != nil { + (*sc.Listener).Close() + } + if sc.UDPListener != nil { + (*sc.UDPListener).Close() + } } } func (s *HTTP) Start(args interface{}) (err error) { @@ -159,6 +184,7 @@ func (s *HTTP) Start(args interface{}) (err error) { return } log.Printf("%s http(s) proxy on %s", *s.cfg.LocalType, (*sc.Listener).Addr()) + s.serverChannels = append(s.serverChannels, &sc) } } return @@ -224,19 +250,18 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut return } var outConn net.Conn - var _outConn interface{} tryCount := 0 maxTryCount := 5 for { + if s.isStop { + return + } if useProxy { if *s.cfg.ParentType == "ssh" { outConn, err = s.getSSHConn(address) } else { // log.Printf("%v", s.outPool) - _outConn, err = s.outPool.Pool.Get() - if err == nil { - outConn = _outConn.(net.Conn) - } + outConn, err = s.outPool.Get() } } else { outConn, err = utils.ConnectHost(s.Resolve(address), *s.cfg.Timeout) @@ -283,7 +308,7 @@ func (s *HTTP) getSSHConn(host string) (outConn net.Conn, err interface{}) { maxTryCount := 1 tryCount := 0 RETRY: - if tryCount >= maxTryCount { + if tryCount >= maxTryCount || s.isStop { return } wait := make(chan bool, 1) @@ -340,7 +365,7 @@ func (s *HTTP) InitOutConnPool() { if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP || *s.cfg.ParentType == TYPE_KCP { //dur int, isTLS bool, certBytes, keyBytes []byte, //parent string, timeout int, InitialCap int, MaxCap int - s.outPool = utils.NewOutPool( + s.outPool = utils.NewOutConn( *s.cfg.CheckParentInterval, *s.cfg.ParentType, s.cfg.KCP, diff --git a/services/mux_bridge.go b/services/mux_bridge.go index ee4afa6..e0f8fe7 100644 --- a/services/mux_bridge.go +++ b/services/mux_bridge.go @@ -21,6 +21,8 @@ type MuxBridge struct { clientControlConns utils.ConcurrentMap router utils.ClientKeyRouter l *sync.Mutex + isStop bool + sc *utils.ServerChannel } func NewMuxBridge() Service { @@ -28,6 +30,7 @@ func NewMuxBridge() Service { cfg: MuxBridgeArgs{}, clientControlConns: utils.NewConcurrentMap(), l: &sync.Mutex{}, + isStop: false, } b.router = utils.NewClientKeyRouter(&b.clientControlConns, 50000) return b @@ -50,7 +53,23 @@ func (s *MuxBridge) CheckArgs() (err error) { return } func (s *MuxBridge) StopService() { - + defer func() { + e := recover() + if e != nil { + log.Printf("stop bridge service crashed,%s", e) + } else { + log.Printf("service bridge stoped,%s", e) + } + }() + s.isStop = true + if s.sc != nil && (*s.sc).Listener != nil { + (*(*s.sc).Listener).Close() + } + for _, g := range s.clientControlConns.Items() { + for _, session := range g.(utils.ConcurrentMap).Items() { + (session.(*smux.Session)).Close() + } + } } func (s *MuxBridge) Start(args interface{}) (err error) { s.cfg = args.(MuxBridgeArgs) @@ -74,6 +93,7 @@ func (s *MuxBridge) Start(args interface{}) (err error) { if err != nil { return } + s.sc = &sc log.Printf("%s bridge on %s", *s.cfg.LocalType, (*sc.Listener).Addr()) return } @@ -111,6 +131,9 @@ func (s *MuxBridge) handler(inConn net.Conn) { return } for { + if s.isStop { + return + } stream, err := session.AcceptStream() if err != nil { session.Close() @@ -118,7 +141,14 @@ func (s *MuxBridge) handler(inConn net.Conn) { log.Printf("server connection %s %s released", serverID, key) return } - go s.callback(stream, serverID, key) + go func() { + defer func() { + if e := recover(); e != nil { + log.Printf("bridge callback crashed,err: %s", e) + } + }() + s.callback(stream, serverID, key) + }() } case CONN_CLIENT: log.Printf("client connection %s connected", key) @@ -151,6 +181,9 @@ func (s *MuxBridge) handler(inConn net.Conn) { // s.clientControlConns.Set(key, session) go func() { for { + if s.isStop { + return + } if session.IsClosed() { s.l.Lock() defer s.l.Unlock() @@ -173,6 +206,9 @@ func (s *MuxBridge) handler(inConn net.Conn) { func (s *MuxBridge) callback(inConn net.Conn, serverID, key string) { try := 20 for { + if s.isStop { + return + } try-- if try == 0 { break diff --git a/services/mux_client.go b/services/mux_client.go index 2ea91ed..b233598 100644 --- a/services/mux_client.go +++ b/services/mux_client.go @@ -14,12 +14,16 @@ import ( ) type MuxClient struct { - cfg MuxClientArgs + cfg MuxClientArgs + isStop bool + sessions utils.ConcurrentMap } func NewMuxClient() Service { return &MuxClient{ - cfg: MuxClientArgs{}, + cfg: MuxClientArgs{}, + isStop: false, + sessions: utils.NewConcurrentMap(), } } @@ -47,7 +51,18 @@ func (s *MuxClient) CheckArgs() (err error) { return } func (s *MuxClient) StopService() { - + defer func() { + e := recover() + if e != nil { + log.Printf("stop client service crashed,%s", e) + } else { + log.Printf("service client stoped,%s", e) + } + }() + s.isStop = true + for _, sess := range s.sessions.Items() { + sess.(*smux.Session).Close() + } } func (s *MuxClient) Start(args interface{}) (err error) { s.cfg = args.(MuxClientArgs) @@ -63,7 +78,8 @@ func (s *MuxClient) Start(args interface{}) (err error) { count = *s.cfg.SessionCount } for i := 1; i <= count; i++ { - log.Printf("session worker[%d] started", i) + key := fmt.Sprintf("worker[%d]", i) + log.Printf("session %s started", key) go func(i int) { defer func() { e := recover() @@ -72,6 +88,9 @@ func (s *MuxClient) Start(args interface{}) (err error) { } }() for { + if s.isStop { + return + } conn, err := s.getParentConn() if err != nil { log.Printf("connection err: %s, retrying...", err) @@ -94,7 +113,14 @@ func (s *MuxClient) Start(args interface{}) (err error) { time.Sleep(time.Second * 3) continue } + if _sess, ok := s.sessions.Get(key); ok { + _sess.(*smux.Session).Close() + } + s.sessions.Set(key, session) for { + if s.isStop { + return + } stream, err := session.AcceptStream() if err != nil { log.Printf("accept stream err: %s, retrying...", err) @@ -153,6 +179,9 @@ func (s *MuxClient) getParentConn() (conn net.Conn, err error) { func (s *MuxClient) ServeUDP(inConn *smux.Stream, localAddr, ID string) { for { + if s.isStop { + return + } inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) srcAddr, body, err := utils.ReadUDPPacket(inConn) inConn.SetDeadline(time.Time{}) @@ -163,7 +192,15 @@ func (s *MuxClient) ServeUDP(inConn *smux.Stream, localAddr, ID string) { break } else { //log.Printf("udp packet revecived:%s,%v", srcAddr, body) - go s.processUDPPacket(inConn, srcAddr, localAddr, body) + go func() { + defer func() { + if e := recover(); e != nil { + log.Printf("client processUDPPacket crashed,err: %s", e) + } + }() + s.processUDPPacket(inConn, srcAddr, localAddr, body) + }() + } } @@ -216,6 +253,9 @@ func (s *MuxClient) ServeConn(inConn *smux.Stream, localAddr, ID string) { var outConn net.Conn i := 0 for { + if s.isStop { + return + } i++ outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout) if err == nil || i == 3 { diff --git a/services/mux_server.go b/services/mux_server.go index 2ffc6b7..27dde50 100644 --- a/services/mux_server.go +++ b/services/mux_server.go @@ -23,13 +23,15 @@ type MuxServer struct { sc utils.ServerChannel sessions utils.ConcurrentMap lockChn chan bool + isStop bool + udpConn *net.Conn } type MuxServerManager struct { cfg MuxServerArgs udpChn chan MuxUDPItem - sc utils.ServerChannel serverID string + servers []*Service } func NewMuxServerManager() Service { @@ -37,8 +39,10 @@ func NewMuxServerManager() Service { cfg: MuxServerArgs{}, udpChn: make(chan MuxUDPItem, 50000), serverID: utils.Uniqueid(), + servers: []*Service{}, } } + func (s *MuxServerManager) Start(args interface{}) (err error) { s.cfg = args.(MuxServerArgs) if err = s.CheckArgs(); err != nil { @@ -100,6 +104,7 @@ func (s *MuxServerManager) Start(args interface{}) (err error) { if err != nil { return } + s.servers = append(s.servers, &server) } return } @@ -107,6 +112,9 @@ func (s *MuxServerManager) Clean() { s.StopService() } func (s *MuxServerManager) StopService() { + for _, server := range s.servers { + (*server).Clean() + } } func (s *MuxServerManager) CheckArgs() (err error) { if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { @@ -131,6 +139,7 @@ func NewMuxServer() Service { udpChn: make(chan MuxUDPItem, 50000), lockChn: make(chan bool, 1), sessions: utils.NewConcurrentMap(), + isStop: false, } } @@ -140,6 +149,29 @@ type MuxUDPItem struct { srcAddr *net.UDPAddr } +func (s *MuxServer) StopService() { + defer func() { + e := recover() + if e != nil { + log.Printf("stop server service crashed,%s", e) + } else { + log.Printf("service server stoped,%s", e) + } + }() + s.isStop = true + for _, sess := range s.sessions.Items() { + sess.(*smux.Session).Close() + } + if s.sc.Listener != nil { + (*s.sc.Listener).Close() + } + if s.sc.UDPListener != nil { + (*s.sc.UDPListener).Close() + } + if s.udpConn != nil { + (*s.udpConn).Close() + } +} func (s *MuxServer) InitService() (err error) { s.UDPConnDeamon() return @@ -185,6 +217,9 @@ func (s *MuxServer) Start(args interface{}) (err error) { var outConn net.Conn var ID string for { + if s.isStop { + return + } outConn, ID, err = s.GetOutConn() if err != nil { utils.CloseConn(&outConn) @@ -228,7 +263,7 @@ func (s *MuxServer) Start(args interface{}) (err error) { return } func (s *MuxServer) Clean() { - + s.StopService() } func (s *MuxServer) GetOutConn() (outConn net.Conn, ID string, err error) { i := 1 @@ -286,10 +321,16 @@ func (s *MuxServer) GetConn(index string) (conn net.Conn, err error) { return } } + if _sess, ok := s.sessions.Get(index); ok { + _sess.(*smux.Session).Close() + } s.sessions.Set(index, session) log.Printf("session[%s] created", index) go func() { for { + if s.isStop { + return + } if session.IsClosed() { s.sessions.Remove(index) break @@ -332,10 +373,19 @@ func (s *MuxServer) UDPConnDeamon() { var ID string var err error for { + if s.isStop { + return + } item := <-s.udpChn RETRY: + if s.isStop { + return + } if outConn == nil { for { + if s.isStop { + return + } outConn, ID, err = s.GetOutConn() if err != nil { outConn = nil @@ -345,10 +395,14 @@ func (s *MuxServer) UDPConnDeamon() { continue } else { go func(outConn net.Conn, ID string) { - go func() { - // outConn.Close() - }() + if s.udpConn != nil { + (*s.udpConn).Close() + } + s.udpConn = &outConn for { + if s.isStop { + return + } outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) srcAddrFromConn, body, err := utils.ReadUDPPacket(outConn) outConn.SetDeadline(time.Time{}) diff --git a/services/service.go b/services/service.go index b133210..8131a68 100644 --- a/services/service.go +++ b/services/service.go @@ -18,12 +18,18 @@ type ServiceItem struct { var servicesMap = map[string]*ServiceItem{} func Regist(name string, s Service, args interface{}) { + servicesMap[name] = &ServiceItem{ S: s, Args: args, Name: name, } } +func Stop(name string) { + if s, ok := servicesMap[name]; ok && s.S != nil { + s.S.Clean() + } +} func Run(name string, args ...interface{}) (service *ServiceItem, err error) { service, ok := servicesMap[name] if ok { diff --git a/services/socks.go b/services/socks.go index 02fcaa6..a79a6c7 100644 --- a/services/socks.go +++ b/services/socks.go @@ -23,7 +23,9 @@ type Socks struct { sshClient *ssh.Client lockChn chan bool udpSC utils.ServerChannel + sc *utils.ServerChannel domainResolver utils.DomainResolver + isStop bool } func NewSocks() Service { @@ -32,6 +34,7 @@ func NewSocks() Service { checker: utils.Checker{}, basicAuth: utils.BasicAuth{}, lockChn: make(chan bool, 1), + isStop: false, } } @@ -103,6 +106,9 @@ func (s *Socks) InitService() (err error) { go func() { //循环检查ssh网络连通性 for { + if s.isStop { + return + } conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2) if err == nil { conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout))) @@ -136,12 +142,25 @@ func (s *Socks) InitService() (err error) { return } func (s *Socks) StopService() { + defer func() { + e := recover() + if e != nil { + log.Printf("stop socks service crashed,%s", e) + } else { + log.Printf("service socks stoped,%s", e) + } + }() + s.isStop = true + s.checker.Stop() if s.sshClient != nil { s.sshClient.Close() } if s.udpSC.UDPListener != nil { s.udpSC.UDPListener.Close() } + if s.sc != nil && (*s.sc).Listener != nil { + (*(*s.sc).Listener).Close() + } } func (s *Socks) Start(args interface{}) (err error) { //start() @@ -166,6 +185,7 @@ func (s *Socks) Start(args interface{}) (err error) { if err != nil { return } + s.sc = &sc log.Printf("%s socks proxy on %s", *s.cfg.LocalType, (*sc.Listener).Addr()) return } @@ -457,6 +477,9 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque return } for { + if s.isStop { + return + } if *s.cfg.Always { outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr()) } else { @@ -563,7 +586,7 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string) (outConn n maxTryCount := 1 tryCount := 0 RETRY: - if tryCount >= maxTryCount { + if tryCount >= maxTryCount || s.isStop { return } wait := make(chan bool, 1) diff --git a/services/sps.go b/services/sps.go index a169aaa..ff02bfb 100644 --- a/services/sps.go +++ b/services/sps.go @@ -17,17 +17,19 @@ import ( ) type SPS struct { - outPool utils.OutPool + outPool utils.OutConn cfg SPSArgs domainResolver utils.DomainResolver basicAuth utils.BasicAuth + serverChannels []*utils.ServerChannel } func NewSPS() Service { return &SPS{ - outPool: utils.OutPool{}, - cfg: SPSArgs{}, - basicAuth: utils.BasicAuth{}, + outPool: utils.OutConn{}, + cfg: SPSArgs{}, + basicAuth: utils.BasicAuth{}, + serverChannels: []*utils.ServerChannel{}, } } func (s *SPS) CheckArgs() (err error) { @@ -66,7 +68,7 @@ func (s *SPS) InitOutConnPool() { if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP || *s.cfg.ParentType == TYPE_KCP { //dur int, isTLS bool, certBytes, keyBytes []byte, //parent string, timeout int, InitialCap int, MaxCap int - s.outPool = utils.NewOutPool( + s.outPool = utils.NewOutConn( 0, *s.cfg.ParentType, s.cfg.KCP, @@ -80,8 +82,21 @@ func (s *SPS) InitOutConnPool() { } func (s *SPS) StopService() { - if s.outPool.Pool != nil { - s.outPool.Pool.ReleaseAll() + defer func() { + e := recover() + if e != nil { + log.Printf("stop sps service crashed,%s", e) + } else { + log.Printf("service sps stoped,%s", e) + } + }() + for _, sc := range s.serverChannels { + if sc.Listener != nil && *sc.Listener != nil { + (*sc.Listener).Close() + } + if sc.UDPListener != nil { + (*sc.UDPListener).Close() + } } } func (s *SPS) Start(args interface{}) (err error) { @@ -109,6 +124,7 @@ func (s *SPS) Start(args interface{}) (err error) { return } log.Printf("%s http(s)+socks proxy on %s", s.cfg.Protocol(), (*sc.Listener).Addr()) + s.serverChannels = append(s.serverChannels, &sc) } } return @@ -207,11 +223,7 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) { } //connect to parent var outConn net.Conn - var _outConn interface{} - _outConn, err = s.outPool.Pool.Get() - if err == nil { - outConn = _outConn.(net.Conn) - } + outConn, err = s.outPool.Get() if err != nil { log.Printf("connect to %s , err:%s", *s.cfg.Parent, err) utils.CloseConn(inConn) diff --git a/services/tcp.go b/services/tcp.go index edf0689..176d289 100644 --- a/services/tcp.go +++ b/services/tcp.go @@ -14,14 +14,17 @@ import ( ) type TCP struct { - outPool utils.OutPool + outPool utils.OutConn cfg TCPArgs + sc *utils.ServerChannel + isStop bool } func NewTCP() Service { return &TCP{ - outPool: utils.OutPool{}, + outPool: utils.OutConn{}, cfg: TCPArgs{}, + isStop: false, } } func (s *TCP) CheckArgs() (err error) { @@ -46,8 +49,20 @@ func (s *TCP) InitService() (err error) { return } func (s *TCP) StopService() { - if s.outPool.Pool != nil { - s.outPool.Pool.ReleaseAll() + defer func() { + e := recover() + if e != nil { + log.Printf("stop tcp service crashed,%s", e) + } else { + log.Printf("service tcp stoped,%s", e) + } + }() + s.isStop = true + if s.sc.Listener != nil && *s.sc.Listener != nil { + (*s.sc.Listener).Close() + } + if s.sc.UDPListener != nil { + (*s.sc.UDPListener).Close() } } func (s *TCP) Start(args interface{}) (err error) { @@ -74,6 +89,7 @@ func (s *TCP) Start(args interface{}) (err error) { return } log.Printf("%s proxy on %s", s.cfg.Protocol(), (*sc.Listener).Addr()) + s.sc = &sc return } @@ -106,11 +122,7 @@ func (s *TCP) callback(inConn net.Conn) { } func (s *TCP) OutToTCP(inConn *net.Conn) (err error) { var outConn net.Conn - var _outConn interface{} - _outConn, err = s.outPool.Pool.Get() - if err == nil { - outConn = _outConn.(net.Conn) - } + outConn, err = s.outPool.Get() if err != nil { log.Printf("connect to %s , err:%s", *s.cfg.Parent, err) utils.CloseConn(inConn) @@ -129,6 +141,9 @@ func (s *TCP) OutToTCP(inConn *net.Conn) (err error) { func (s *TCP) OutToUDP(inConn *net.Conn) (err error) { log.Printf("conn created , remote : %s ", (*inConn).RemoteAddr()) for { + if s.isStop { + return + } srcAddr, body, err := utils.ReadUDPPacket(bufio.NewReader(*inConn)) if err == io.EOF || err == io.ErrUnexpectedEOF { //log.Printf("connection %s released", srcAddr) @@ -178,7 +193,7 @@ func (s *TCP) InitOutConnPool() { if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP || *s.cfg.ParentType == TYPE_KCP { //dur int, isTLS bool, certBytes, keyBytes []byte, //parent string, timeout int, InitialCap int, MaxCap int - s.outPool = utils.NewOutPool( + s.outPool = utils.NewOutConn( *s.cfg.CheckParentInterval, *s.cfg.ParentType, s.cfg.KCP, diff --git a/services/tunnel_bridge.go b/services/tunnel_bridge.go index f7122b3..f9c6077 100644 --- a/services/tunnel_bridge.go +++ b/services/tunnel_bridge.go @@ -18,8 +18,7 @@ type TunnelBridge struct { cfg TunnelBridgeArgs serverConns utils.ConcurrentMap clientControlConns utils.ConcurrentMap - // cmServer utils.ConnManager - // cmClient utils.ConnManager + isStop bool } func NewTunnelBridge() Service { @@ -27,8 +26,7 @@ func NewTunnelBridge() Service { cfg: TunnelBridgeArgs{}, serverConns: utils.NewConcurrentMap(), clientControlConns: utils.NewConcurrentMap(), - // cmServer: utils.NewConnManager(), - // cmClient: utils.NewConnManager(), + isStop: false, } } @@ -44,7 +42,21 @@ func (s *TunnelBridge) CheckArgs() (err error) { return } func (s *TunnelBridge) StopService() { - + defer func() { + e := recover() + if e != nil { + log.Printf("stop tbridge service crashed,%s", e) + } else { + log.Printf("service tbridge stoped,%s", e) + } + }() + s.isStop = true + for _, sess := range s.clientControlConns.Items() { + (*sess.(*net.Conn)).Close() + } + for _, sess := range s.serverConns.Items() { + (*sess.(ServerConn).Conn).Close() + } } func (s *TunnelBridge) Start(args interface{}) (err error) { s.cfg = args.(TunnelBridgeArgs) @@ -85,6 +97,9 @@ func (s *TunnelBridge) Start(args interface{}) (err error) { Conn: &inConn, }) for { + if s.isStop { + return + } item, ok := s.clientControlConns.Get(key) if !ok { log.Printf("client %s control conn not exists", key) diff --git a/services/tunnel_client.go b/services/tunnel_client.go index 012756f..4f0f51a 100644 --- a/services/tunnel_client.go +++ b/services/tunnel_client.go @@ -14,12 +14,14 @@ type TunnelClient struct { cfg TunnelClientArgs // cm utils.ConnManager ctrlConn net.Conn + isStop bool } func NewTunnelClient() Service { return &TunnelClient{ cfg: TunnelClientArgs{}, // cm: utils.NewConnManager(), + isStop: false, } } @@ -42,7 +44,18 @@ func (s *TunnelClient) CheckArgs() (err error) { return } func (s *TunnelClient) StopService() { - // s.cm.RemoveAll() + defer func() { + e := recover() + if e != nil { + log.Printf("stop tclient service crashed,%s", e) + } else { + log.Printf("service tclient stoped,%s", e) + } + }() + s.isStop = true + if s.ctrlConn != nil { + s.ctrlConn.Close() + } } func (s *TunnelClient) Start(args interface{}) (err error) { s.cfg = args.(TunnelClientArgs) @@ -55,8 +68,9 @@ func (s *TunnelClient) Start(args interface{}) (err error) { log.Printf("proxy on tunnel client mode") for { - //close all conn - // s.cm.Remove(*s.cfg.Key) + if s.isStop { + return + } if s.ctrlConn != nil { s.ctrlConn.Close() } @@ -71,6 +85,9 @@ func (s *TunnelClient) Start(args interface{}) (err error) { continue } for { + if s.isStop { + return + } var ID, clientLocalAddr, serverID string err = utils.ReadPacketData(s.ctrlConn, &ID, &clientLocalAddr, &serverID) if err != nil { @@ -121,6 +138,9 @@ func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) { var err error // for { for { + if s.isStop { + return + } // s.cm.RemoveOne(*s.cfg.Key, ID) inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) if err != nil { @@ -136,6 +156,9 @@ func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) { log.Printf("conn %s created", ID) for { + if s.isStop { + return + } srcAddr, body, err := utils.ReadUDPPacket(inConn) if err == io.EOF || err == io.ErrUnexpectedEOF { log.Printf("connection %s released", ID) @@ -192,6 +215,9 @@ func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) { var inConn, outConn net.Conn var err error for { + if s.isStop { + return + } inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID) if err != nil { utils.CloseConn(&inConn) @@ -205,6 +231,9 @@ func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) { i := 0 for { + if s.isStop { + return + } i++ outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout) if err == nil || i == 3 { diff --git a/services/tunnel_server.go b/services/tunnel_server.go index e3ed0cf..d0dcc16 100644 --- a/services/tunnel_server.go +++ b/services/tunnel_server.go @@ -14,17 +14,18 @@ import ( ) type TunnelServer struct { - cfg TunnelServerArgs - udpChn chan UDPItem - sc utils.ServerChannel + cfg TunnelServerArgs + udpChn chan UDPItem + sc utils.ServerChannel + isStop bool + udpConn *net.Conn } type TunnelServerManager struct { cfg TunnelServerArgs udpChn chan UDPItem - sc utils.ServerChannel serverID string - // cm utils.ConnManager + servers []*Service } func NewTunnelServerManager() Service { @@ -32,7 +33,7 @@ func NewTunnelServerManager() Service { cfg: TunnelServerArgs{}, udpChn: make(chan UDPItem, 50000), serverID: utils.Uniqueid(), - // cm: utils.NewConnManager(), + servers: []*Service{}, } } func (s *TunnelServerManager) Start(args interface{}) (err error) { @@ -89,6 +90,7 @@ func (s *TunnelServerManager) Start(args interface{}) (err error) { if err != nil { return } + s.servers = append(s.servers, &server) } return } @@ -96,7 +98,9 @@ func (s *TunnelServerManager) Clean() { s.StopService() } func (s *TunnelServerManager) StopService() { - // s.cm.RemoveAll() + for _, server := range s.servers { + (*server).Clean() + } } func (s *TunnelServerManager) CheckArgs() (err error) { if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" { @@ -137,6 +141,7 @@ func NewTunnelServer() Service { return &TunnelServer{ cfg: TunnelServerArgs{}, udpChn: make(chan UDPItem, 50000), + isStop: false, } } @@ -146,6 +151,27 @@ type UDPItem struct { srcAddr *net.UDPAddr } +func (s *TunnelServer) StopService() { + defer func() { + e := recover() + if e != nil { + log.Printf("stop server service crashed,%s", e) + } else { + log.Printf("service server stoped,%s", e) + } + }() + s.isStop = true + + if s.sc.Listener != nil { + (*s.sc.Listener).Close() + } + if s.sc.UDPListener != nil { + (*s.sc.UDPListener).Close() + } + if s.udpConn != nil { + (*s.udpConn).Close() + } +} func (s *TunnelServer) InitService() (err error) { s.UDPConnDeamon() return @@ -191,6 +217,9 @@ func (s *TunnelServer) Start(args interface{}) (err error) { var outConn net.Conn var ID string for { + if s.isStop { + return + } outConn, ID, err = s.GetOutConn(CONN_SERVER) if err != nil { utils.CloseConn(&outConn) @@ -259,10 +288,19 @@ func (s *TunnelServer) UDPConnDeamon() { // var cmdChn = make(chan bool, 1000) var err error for { + if s.isStop { + return + } item := <-s.udpChn RETRY: + if s.isStop { + return + } if outConn == nil { for { + if s.isStop { + return + } outConn, ID, err = s.GetOutConn(CONN_SERVER) if err != nil { // cmdChn <- true @@ -273,11 +311,14 @@ func (s *TunnelServer) UDPConnDeamon() { continue } else { go func(outConn net.Conn, ID string) { - go func() { - // <-cmdChn - // outConn.Close() - }() + if s.udpConn != nil { + (*s.udpConn).Close() + } + s.udpConn = &outConn for { + if s.isStop { + return + } srcAddrFromConn, body, err := utils.ReadUDPPacket(outConn) if err == io.EOF || err == io.ErrUnexpectedEOF { log.Printf("UDP deamon connection %s exited", ID) diff --git a/services/udp.go b/services/udp.go index 400b252..09f6aeb 100644 --- a/services/udp.go +++ b/services/udp.go @@ -17,15 +17,17 @@ import ( type UDP struct { p utils.ConcurrentMap - outPool utils.OutPool + outPool utils.OutConn cfg UDPArgs sc *utils.ServerChannel + isStop bool } func NewUDP() Service { return &UDP{ - outPool: utils.OutPool{}, + outPool: utils.OutConn{}, p: utils.NewConcurrentMap(), + isStop: false, } } func (s *UDP) CheckArgs() (err error) { @@ -52,8 +54,20 @@ func (s *UDP) InitService() (err error) { return } func (s *UDP) StopService() { - if s.outPool.Pool != nil { - s.outPool.Pool.ReleaseAll() + defer func() { + e := recover() + if e != nil { + log.Printf("stop udp service crashed,%s", e) + } else { + log.Printf("service udp stoped,%s", e) + } + }() + s.isStop = true + if s.sc.Listener != nil && *s.sc.Listener != nil { + (*s.sc.Listener).Close() + } + if s.sc.UDPListener != nil { + (*s.sc.UDPListener).Close() } } func (s *UDP) Start(args interface{}) (err error) { @@ -105,7 +119,7 @@ func (s *UDP) GetConn(connKey string) (conn net.Conn, isNew bool, err error) { isNew = !s.p.Has(connKey) var _conn interface{} if isNew { - _conn, err = s.outPool.Pool.Get() + _conn, err = s.outPool.Get() if err != nil { return nil, false, err } @@ -138,6 +152,9 @@ func (s *UDP) OutToTCP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err erro }() log.Printf("conn %d created , local: %s", connKey, srcAddr.String()) for { + if s.isStop { + return + } srcAddrFromConn, body, err := utils.ReadUDPPacket(bufio.NewReader(conn)) if err == io.EOF || err == io.ErrUnexpectedEOF { //log.Printf("connection %d released", connKey) @@ -216,7 +233,7 @@ func (s *UDP) InitOutConnPool() { if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP { //dur int, isTLS bool, certBytes, keyBytes []byte, //parent string, timeout int, InitialCap int, MaxCap int - s.outPool = utils.NewOutPool( + s.outPool = utils.NewOutConn( *s.cfg.CheckParentInterval, *s.cfg.ParentType, kcpcfg.KCPConfigArgs{}, diff --git a/utils/pool.go b/utils/pool.go deleted file mode 100755 index ae30f6f..0000000 --- a/utils/pool.go +++ /dev/null @@ -1,145 +0,0 @@ -package utils - -import ( - "log" - "sync" - "time" -) - -//ConnPool to use -type ConnPool interface { - Get() (conn interface{}, err error) - Put(conn interface{}) - ReleaseAll() - Len() (length int) -} -type poolConfig struct { - Factory func() (interface{}, error) - IsActive func(interface{}) bool - Release func(interface{}) - InitialCap int - MaxCap int -} - -func NewConnPool(poolConfig poolConfig) (pool ConnPool, err error) { - p := netPool{ - config: poolConfig, - conns: make(chan interface{}, poolConfig.MaxCap), - lock: &sync.Mutex{}, - } - //log.Printf("pool MaxCap:%d", poolConfig.MaxCap) - if poolConfig.MaxCap > 0 { - err = p.initAutoFill(false) - if err == nil { - p.initAutoFill(true) - } - } - return &p, nil -} - -type netPool struct { - conns chan interface{} - lock *sync.Mutex - config poolConfig -} - -func (p *netPool) initAutoFill(async bool) (err error) { - var worker = func() (err error) { - for { - //log.Printf("pool fill: %v , len: %d", p.Len() <= p.config.InitialCap/2, p.Len()) - if p.Len() <= p.config.InitialCap/2 { - p.lock.Lock() - errN := 0 - for i := 0; i < p.config.InitialCap; i++ { - c, err := p.config.Factory() - if err != nil { - errN++ - if async { - continue - } else { - p.lock.Unlock() - return err - } - } - select { - case p.conns <- c: - default: - p.config.Release(c) - break - } - if p.Len() >= p.config.InitialCap { - break - } - } - if errN > 0 { - log.Printf("fill conn pool fail , ERRN:%d", errN) - } - p.lock.Unlock() - } - if !async { - return - } - time.Sleep(time.Second * 2) - } - } - if async { - go worker() - } else { - err = worker() - } - return - -} - -func (p *netPool) Get() (conn interface{}, err error) { - // defer func() { - // log.Printf("pool len : %d", p.Len()) - // }() - p.lock.Lock() - defer p.lock.Unlock() - // for { - select { - case conn = <-p.conns: - if p.config.IsActive(conn) { - return - } - p.config.Release(conn) - default: - conn, err = p.config.Factory() - if err != nil { - return nil, err - } - return conn, nil - } - // } - return -} - -func (p *netPool) Put(conn interface{}) { - if conn == nil { - return - } - p.lock.Lock() - defer p.lock.Unlock() - if !p.config.IsActive(conn) { - p.config.Release(conn) - } - select { - case p.conns <- conn: - default: - p.config.Release(conn) - } -} -func (p *netPool) ReleaseAll() { - p.lock.Lock() - defer p.lock.Unlock() - close(p.conns) - for c := range p.conns { - p.config.Release(c) - } - p.conns = make(chan interface{}, p.config.InitialCap) - -} -func (p *netPool) Len() (length int) { - return len(p.conns) -} diff --git a/utils/structs.go b/utils/structs.go index e7f3958..6279530 100644 --- a/utils/structs.go +++ b/utils/structs.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/tls" "encoding/base64" - "encoding/binary" "errors" "fmt" "io" @@ -28,6 +27,7 @@ type Checker struct { directMap ConcurrentMap interval int64 timeout int + isStop bool } type CheckerItem struct { IsHTTPS bool @@ -48,6 +48,7 @@ func NewChecker(timeout int, interval int64, blockedFile, directFile string) Che data: NewConcurrentMap(), interval: interval, timeout: timeout, + isStop: false, } ch.blockedMap = ch.loadMap(blockedFile) ch.directMap = ch.loadMap(directFile) @@ -81,6 +82,9 @@ func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) { } return } +func (c *Checker) Stop() { + c.isStop = true +} func (c *Checker) start() { go func() { //log.Printf("checker started") @@ -107,6 +111,9 @@ func (c *Checker) start() { }(v.(CheckerItem)) } time.Sleep(time.Second * time.Duration(c.interval)) + if c.isStop { + return + } } }() } @@ -498,8 +505,7 @@ func (req *HTTPRequest) addPortIfNot() (newHost string) { return } -type OutPool struct { - Pool ConnPool +type OutConn struct { dur int typ string certBytes []byte @@ -510,8 +516,8 @@ type OutPool struct { timeout int } -func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyBytes, caCertBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) { - op = OutPool{ +func NewOutConn(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyBytes, caCertBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutConn) { + return OutConn{ dur: dur, typ: typ, certBytes: certBytes, @@ -521,36 +527,8 @@ func NewOutPool(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyByt address: address, timeout: timeout, } - var err error - op.Pool, err = NewConnPool(poolConfig{ - IsActive: func(conn interface{}) bool { return true }, - Release: func(conn interface{}) { - if conn != nil { - conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond)) - conn.(net.Conn).Close() - // log.Println("conn released") - } - }, - InitialCap: InitialCap, - MaxCap: MaxCap, - Factory: func() (conn interface{}, err error) { - conn, err = op.getConn() - return - }, - }) - if err != nil { - log.Fatalf("init conn pool fail ,%s", err) - } else { - if InitialCap > 0 { - log.Printf("init conn pool success") - op.initPoolDeamon() - } else { - log.Printf("conn pool closed") - } - } - return } -func (op *OutPool) getConn() (conn interface{}, err error) { +func (op *OutConn) Get() (conn net.Conn, err error) { if op.typ == "tls" { var _conn tls.Conn _conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes, op.caCertBytes) @@ -565,176 +543,6 @@ func (op *OutPool) getConn() (conn interface{}, err error) { return } -func (op *OutPool) initPoolDeamon() { - go func() { - if op.dur <= 0 { - return - } - log.Printf("pool deamon started") - for { - time.Sleep(time.Second * time.Duration(op.dur)) - conn, err := op.getConn() - if err != nil { - log.Printf("pool deamon err %s , release pool", err) - op.Pool.ReleaseAll() - } else { - conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond)) - conn.(net.Conn).Close() - } - } - }() -} - -type HeartbeatData struct { - Data []byte - N int - Error error -} -type HeartbeatReadWriter struct { - conn *net.Conn - // rchn chan HeartbeatData - l *sync.Mutex - dur int - errHandler func(err error, hb *HeartbeatReadWriter) - once *sync.Once - datachn chan byte - // rbuf bytes.Buffer - // signal chan bool - rerrchn chan error -} - -func NewHeartbeatReadWriter(conn *net.Conn, dur int, fn func(err error, hb *HeartbeatReadWriter)) (hrw HeartbeatReadWriter) { - hrw = HeartbeatReadWriter{ - conn: conn, - l: &sync.Mutex{}, - dur: dur, - // rchn: make(chan HeartbeatData, 10000), - // signal: make(chan bool, 1), - errHandler: fn, - datachn: make(chan byte, 4*1024), - once: &sync.Once{}, - rerrchn: make(chan error, 1), - // rbuf: bytes.Buffer{}, - } - hrw.heartbeat() - hrw.reader() - return -} - -func (rw *HeartbeatReadWriter) Close() { - CloseConn(rw.conn) -} -func (rw *HeartbeatReadWriter) reader() { - go func() { - //log.Printf("heartbeat read started") - for { - n, data, err := rw.read() - if n == -1 { - continue - } - //log.Printf("n:%d , data:%s ,err:%s", n, string(data), err) - if err == nil { - //fmt.Printf("write data %s\n", string(data)) - for _, b := range data { - rw.datachn <- b - } - } - if err != nil { - //log.Printf("heartbeat reader err: %s", err) - select { - case rw.rerrchn <- err: - default: - } - rw.once.Do(func() { - rw.errHandler(err, rw) - }) - break - } - } - //log.Printf("heartbeat read exited") - }() -} -func (rw *HeartbeatReadWriter) read() (n int, data []byte, err error) { - var typ uint8 - err = binary.Read((*rw.conn), binary.LittleEndian, &typ) - if err != nil { - return - } - if typ == 0 { - // log.Printf("heartbeat revecived") - n = -1 - return - } - var dataLength uint32 - binary.Read((*rw.conn), binary.LittleEndian, &dataLength) - _data := make([]byte, dataLength) - // log.Printf("dataLength:%d , data:%s", dataLength, string(data)) - n, err = (*rw.conn).Read(_data) - //log.Printf("n:%d , data:%s ,err:%s", n, string(data), err) - if err != nil { - return - } - if uint32(n) != dataLength { - err = fmt.Errorf("read short data body") - return - } - data = _data[:n] - return -} -func (rw *HeartbeatReadWriter) heartbeat() { - go func() { - //log.Printf("heartbeat started") - for { - if rw.conn == nil || *rw.conn == nil { - //log.Printf("heartbeat err: conn nil") - break - } - rw.l.Lock() - _, err := (*rw.conn).Write([]byte{0}) - rw.l.Unlock() - if err != nil { - //log.Printf("heartbeat err: %s", err) - rw.once.Do(func() { - rw.errHandler(err, rw) - }) - break - } else { - // log.Printf("heartbeat send ok") - } - time.Sleep(time.Second * time.Duration(rw.dur)) - } - //log.Printf("heartbeat exited") - }() -} -func (rw *HeartbeatReadWriter) Read(p []byte) (n int, err error) { - data := make([]byte, cap(p)) - for i := 0; i < cap(p); i++ { - data[i] = <-rw.datachn - n++ - //fmt.Printf("read %d %v\n", i, data[:n]) - if len(rw.datachn) == 0 { - n = i + 1 - copy(p, data[:n]) - return - } - } - return -} -func (rw *HeartbeatReadWriter) Write(p []byte) (n int, err error) { - defer rw.l.Unlock() - rw.l.Lock() - pkg := new(bytes.Buffer) - binary.Write(pkg, binary.LittleEndian, uint8(1)) - binary.Write(pkg, binary.LittleEndian, uint32(len(p))) - binary.Write(pkg, binary.LittleEndian, p) - bs := pkg.Bytes() - n, err = (*rw.conn).Write(bs) - if err == nil { - n = len(p) - } - return -} - type ConnManager struct { pool ConcurrentMap l *sync.Mutex