diff --git a/core/cs/client/client.go b/core/cs/client/client.go new file mode 100644 index 0000000..034ad21 --- /dev/null +++ b/core/cs/client/client.go @@ -0,0 +1,132 @@ +package client + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "strconv" + "strings" + "time" + + "github.com/snail007/goproxy/core/lib/kcpcfg" + compressconn "github.com/snail007/goproxy/core/lib/transport" + encryptconn "github.com/snail007/goproxy/core/lib/transport/encrypt" + "github.com/snail007/goproxy/core/dst" + kcp "github.com/xtaci/kcp-go" +) + +func TlsConnectHost(host string, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) { + h := strings.Split(host, ":") + port, _ := strconv.Atoi(h[1]) + return TlsConnect(h[0], port, timeout, certBytes, keyBytes, caCertBytes) +} + +func TlsConnect(host string, port, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) { + conf, err := getRequestTlsConfig(certBytes, keyBytes, caCertBytes) + if err != nil { + return + } + _conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond) + if err != nil { + return + } + return *tls.Client(_conn, conf), err +} +func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) { + + var cert tls.Certificate + cert, err = tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + return + } + serverCertPool := x509.NewCertPool() + caBytes := certBytes + if caCertBytes != nil { + caBytes = caCertBytes + + } + ok := serverCertPool.AppendCertsFromPEM(caBytes) + if !ok { + err = errors.New("failed to parse root certificate") + } + block, _ := pem.Decode(caBytes) + if block == nil { + panic("failed to parse certificate PEM") + } + x509Cert, _ := x509.ParseCertificate(block.Bytes) + if x509Cert == nil { + panic("failed to parse block") + } + conf = &tls.Config{ + RootCAs: serverCertPool, + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + ServerName: x509Cert.Subject.CommonName, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + opts := x509.VerifyOptions{ + Roots: serverCertPool, + } + for _, rawCert := range rawCerts { + cert, _ := x509.ParseCertificate(rawCert) + _, err := cert.Verify(opts) + if err != nil { + return err + } + } + return nil + }, + } + return +} + +func TCPConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) { + conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond) + return +} + +func TCPSConnectHost(hostAndPort string, method, password string, compress bool, timeout int) (conn net.Conn, err error) { + conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond) + if err != nil { + return + } + if compress { + conn = compressconn.NewCompConn(conn) + } + conn, err = encryptconn.NewConn(conn, method, password) + return +} + +func TOUConnectHost(hostAndPort string, method, password string, compress bool, timeout int) (conn net.Conn, err error) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + panic(err) + } + // Create a DST mux around the packet connection with the default max + // packet size. + mux := dst.NewMux(udpConn, 0) + conn, err = mux.Dial("dst", hostAndPort) + if compress { + conn = compressconn.NewCompConn(conn) + } + conn, err = encryptconn.NewConn(conn, method, password) + return +} +func KCPConnectHost(hostAndPort string, config kcpcfg.KCPConfigArgs) (conn net.Conn, err error) { + kcpconn, err := kcp.DialWithOptions(hostAndPort, config.Block, *config.DataShard, *config.ParityShard) + if err != nil { + return + } + kcpconn.SetStreamMode(true) + kcpconn.SetWriteDelay(true) + kcpconn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion) + kcpconn.SetMtu(*config.MTU) + kcpconn.SetWindowSize(*config.SndWnd, *config.RcvWnd) + kcpconn.SetACKNoDelay(*config.AckNodelay) + if *config.NoComp { + return kcpconn, err + } + return compressconn.NewCompStream(kcpconn), err +} diff --git a/core/cs/server/server.go b/core/cs/server/server.go new file mode 100644 index 0000000..3a46b9f --- /dev/null +++ b/core/cs/server/server.go @@ -0,0 +1,342 @@ +package server + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + + logger "log" + "net" + "runtime/debug" + "strconv" + + compressconn "github.com/snail007/goproxy/core/lib/transport" + transportc "github.com/snail007/goproxy/core/lib/transport" + encryptconn "github.com/snail007/goproxy/core/lib/transport/encrypt" + tou "github.com/snail007/goproxy/core/dst" + + "github.com/snail007/goproxy/core/lib/kcpcfg" + + kcp "github.com/xtaci/kcp-go" +) + +func init() { + +} + +type ServerChannel struct { + ip string + port int + Listener *net.Listener + UDPListener *net.UDPConn + errAcceptHandler func(err error) + log *logger.Logger + TOUServer *tou.Mux +} + +func NewServerChannel(ip string, port int, log *logger.Logger) ServerChannel { + return ServerChannel{ + ip: ip, + port: port, + log: log, + errAcceptHandler: func(err error) { + log.Printf("accept error , ERR:%s", err) + }, + } +} +func NewServerChannelHost(host string, log *logger.Logger) ServerChannel { + h, port, _ := net.SplitHostPort(host) + p, _ := strconv.Atoi(port) + return ServerChannel{ + ip: h, + port: p, + log: log, + errAcceptHandler: func(err error) { + log.Printf("accept error , ERR:%s", err) + }, + } +} +func (s *ServerChannel) SetErrAcceptHandler(fn func(err error)) { + s.errAcceptHandler = fn +} +func (s *ServerChannel) ListenSingleTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) { + return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, true) + +} +func (s *ServerChannel) ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) { + return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, false) +} +func (s *ServerChannel) _ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn), single bool) (err error) { + s.Listener, err = s.listenTLS(s.ip, s.port, certBytes, keyBytes, caCertBytes, single) + if err == nil { + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("ListenTLS crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + for { + var conn net.Conn + conn, err = (*s.Listener).Accept() + if err == nil { + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("tls connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + fn(conn) + }() + } else { + s.errAcceptHandler(err) + (*s.Listener).Close() + break + } + } + }() + } + return +} +func (s *ServerChannel) listenTLS(ip string, port int, certBytes, keyBytes, caCertBytes []byte, single bool) (ln *net.Listener, err error) { + var cert tls.Certificate + cert, err = tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + return + } + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + if !single { + clientCertPool := x509.NewCertPool() + caBytes := certBytes + if caCertBytes != nil { + caBytes = caCertBytes + } + ok := clientCertPool.AppendCertsFromPEM(caBytes) + if !ok { + err = errors.New("failed to parse root certificate") + } + config.ClientCAs = clientCertPool + config.ClientAuth = tls.RequireAndVerifyClientCert + } + _ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config) + if err == nil { + ln = &_ln + } + return +} +func (s *ServerChannel) ListenTCPS(method, password string, compress bool, fn func(conn net.Conn)) (err error) { + _, err = encryptconn.NewCipher(method, password) + if err != nil { + return + } + return s.ListenTCP(func(c net.Conn) { + if compress { + c = transportc.NewCompConn(c) + } + c, _ = encryptconn.NewConn(c, method, password) + fn(c) + }) +} +func (s *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) { + var l net.Listener + l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.ip, s.port)) + if err == nil { + s.Listener = &l + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + for { + var conn net.Conn + conn, err = (*s.Listener).Accept() + if err == nil { + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + fn(conn) + }() + } else { + s.errAcceptHandler(err) + (*s.Listener).Close() + break + } + } + }() + } + return +} +func (s *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) { + addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port} + l, err := net.ListenUDP("udp", addr) + if err == nil { + s.UDPListener = l + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + for { + var buf = make([]byte, 2048) + n, srcAddr, err := (*s.UDPListener).ReadFromUDP(buf) + if err == nil { + packet := buf[0:n] + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + fn(packet, addr, srcAddr) + }() + } else { + s.errAcceptHandler(err) + (*s.UDPListener).Close() + break + } + } + }() + } + return +} +func (s *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net.Conn), log *logger.Logger) (err error) { + lis, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", s.ip, s.port), config.Block, *config.DataShard, *config.ParityShard) + if err == nil { + if err = lis.SetDSCP(*config.DSCP); err != nil { + log.Println("SetDSCP:", err) + return + } + if err = lis.SetReadBuffer(*config.SockBuf); err != nil { + log.Println("SetReadBuffer:", err) + return + } + if err = lis.SetWriteBuffer(*config.SockBuf); err != nil { + log.Println("SetWriteBuffer:", err) + return + } + s.Listener = new(net.Listener) + *s.Listener = lis + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("ListenKCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + for { + //var conn net.Conn + conn, err := lis.AcceptKCP() + if err == nil { + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("kcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + conn.SetStreamMode(true) + conn.SetWriteDelay(true) + conn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion) + conn.SetMtu(*config.MTU) + conn.SetWindowSize(*config.SndWnd, *config.RcvWnd) + conn.SetACKNoDelay(*config.AckNodelay) + if *config.NoComp { + fn(conn) + } else { + cconn := transportc.NewCompStream(conn) + fn(cconn) + } + }() + } else { + s.errAcceptHandler(err) + (*s.Listener).Close() + break + } + } + }() + } + return +} + +func (s *ServerChannel) ListenTOU(method, password string, compress bool, fn func(conn net.Conn)) (err error) { + addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port} + s.UDPListener, err = net.ListenUDP("udp", addr) + if err != nil { + s.log.Println(err) + return + } + s.TOUServer = tou.NewMux(s.UDPListener, 0) + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("ListenRUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + for { + var conn net.Conn + conn, err = (*s.TOUServer).Accept() + if err == nil { + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack())) + } + }() + if compress { + conn = compressconn.NewCompConn(conn) + } + conn, err = encryptconn.NewConn(conn, method, password) + if err != nil { + conn.Close() + s.log.Println(err) + return + } + fn(conn) + }() + } else { + s.errAcceptHandler(err) + s.TOUServer.Close() + s.UDPListener.Close() + break + } + } + }() + + return +} +func (s *ServerChannel) Close() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack())) + } + }() + if s.Listener != nil && *s.Listener != nil { + (*s.Listener).Close() + } + if s.TOUServer != nil { + s.TOUServer.Close() + } + if s.UDPListener != nil { + s.UDPListener.Close() + } +} +func (s *ServerChannel) Addr() string { + defer func() { + if e := recover(); e != nil { + s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack())) + } + }() + if s.Listener != nil && *s.Listener != nil { + return (*s.Listener).Addr().String() + } + + if s.UDPListener != nil { + return s.UDPListener.LocalAddr().String() + } + return "" +} diff --git a/core/cs/tests/transport_test.go b/core/cs/tests/transport_test.go new file mode 100644 index 0000000..d530314 --- /dev/null +++ b/core/cs/tests/transport_test.go @@ -0,0 +1,49 @@ +package tests + +import ( + "log" + "net" + "os" + "testing" + + ctransport "github.com/snail007/goproxy/core/cs/client" + stransport "github.com/snail007/goproxy/core/cs/server" +) + +func TestTCPS(t *testing.T) { + l := log.New(os.Stderr, "", log.LstdFlags) + s := stransport.NewServerChannelHost(":", l) + err := s.ListenTCPS("aes-256-cfb", "password", true, func(inconn net.Conn) { + buf := make([]byte, 2048) + _, err := inconn.Read(buf) + if err != nil { + t.Error(err) + return + } + _, err = inconn.Write([]byte("okay")) + if err != nil { + t.Error(err) + return + } + }) + if err != nil { + t.Fatal(err) + } + client, err := ctransport.TCPSConnectHost((*s.Listener).Addr().String(), "aes-256-cfb", "password", true, 1000) + if err != nil { + t.Fatal(err) + } + defer client.Close() + _, err = client.Write([]byte("test")) + if err != nil { + t.Fatal(err) + } + b := make([]byte, 20) + n, err := client.Read(b) + if err != nil { + t.Fatal(err) + } + if string(b[:n]) != "okay" { + t.Fatalf("client revecive okay excepted,revecived : %s", string(b[:n])) + } +} diff --git a/core/dst/conn.go b/core/dst/conn.go new file mode 100644 index 0000000..3e63341 --- /dev/null +++ b/core/dst/conn.go @@ -0,0 +1,586 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + "bytes" + crand "crypto/rand" + "encoding/binary" + "fmt" + "io" + + "math/rand" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + defExpTime = 100 * time.Millisecond // N * (4 * RTT + RTTVar + SYN) + expCountClose = 8 // close connection after this many Exps + minTimeClose = 5 * time.Second // if at least this long has passed + maxInputBuffer = 8 << 20 // bytes + muxBufferPackets = 128 // buffer size of channel between mux and reader routine + rttMeasureWindow = 32 // number of packets to track for RTT averaging + rttMeasureSample = 128 // Sample every ... packet for RTT + + // number of bytes to subtract from MTU when chunking data, to try to + // avoid fragmentation + sliceOverhead = 8 /*pppoe, similar*/ + 20 /*ipv4*/ + 8 /*udp*/ + 16 /*dst*/ +) + +func init() { + // Properly seed the random number generator that we use for sequence + // numbers and stuff. + buf := make([]byte, 8) + if n, err := crand.Read(buf); n != 8 || err != nil { + panic("init random failure") + } + rand.Seed(int64(binary.BigEndian.Uint64(buf))) +} + +// TODO: export this interface when it's usable from the outside +type congestionController interface { + Ack() + NegAck() + Exp() + SendWindow() int + PacketRate() int // PPS + UpdateRTT(time.Duration) +} + +// Conn is an SDT connection carried over a Mux. +type Conn struct { + // Set at creation, thereafter immutable: + + mux *Mux + dst net.Addr + connID connectionID + remoteConnID connectionID + in chan packet + cc congestionController + packetSize int + closed chan struct{} + closeOnce sync.Once + + // Touched by more than one goroutine, needs locking. + + nextSeqNoMut sync.Mutex + nextSeqNo sequenceNo + + inbufMut sync.Mutex + inbufCond *sync.Cond + inbuf bytes.Buffer + + expMut sync.Mutex + exp *time.Timer + + sendBuffer *sendBuffer // goroutine safe + + packetDelays [rttMeasureWindow]time.Duration + packetDelaysSlot int + packetDelaysMut sync.Mutex + + // Owned by the reader routine, needs no locking + + recvBuffer packetList + nextRecvSeqNo sequenceNo + lastAckedSeqNo sequenceNo + lastNegAckedSeqNo sequenceNo + expCount int + expReset time.Time + + // Only accessed atomically + + packetsIn int64 + packetsOut int64 + bytesIn int64 + bytesOut int64 + resentPackets int64 + droppedPackets int64 + outOfOrderPackets int64 + + // Special + + debugResetRecvSeqNo chan sequenceNo +} + +func newConn(m *Mux, dst net.Addr) *Conn { + conn := &Conn{ + mux: m, + dst: dst, + nextSeqNo: sequenceNo(rand.Uint32()), + packetSize: maxPacketSize, + in: make(chan packet, muxBufferPackets), + closed: make(chan struct{}), + sendBuffer: newSendBuffer(m), + exp: time.NewTimer(defExpTime), + debugResetRecvSeqNo: make(chan sequenceNo), + expReset: time.Now(), + } + + conn.lastAckedSeqNo = conn.nextSeqNo - 1 + conn.inbufCond = sync.NewCond(&conn.inbufMut) + + conn.cc = newWindowCC() + conn.sendBuffer.SetWindowAndRate(conn.cc.SendWindow(), conn.cc.PacketRate()) + conn.recvBuffer.Resize(128) + + return conn +} + +func (c *Conn) start() { + go c.reader() +} + +func (c *Conn) reader() { + if debugConnection { + log.Println(c, "reader() starting") + defer log.Println(c, "reader() exiting") + } + + for { + select { + case <-c.closed: + // Ack any received but not yet acked messages. + c.sendAck(0) + + // Send a shutdown message. + c.nextSeqNoMut.Lock() + c.mux.write(packet{ + src: c.connID, + dst: c.dst, + hdr: header{ + packetType: typeShutdown, + connID: c.remoteConnID, + sequenceNo: c.nextSeqNo, + }, + }) + c.nextSeqNo++ + c.nextSeqNoMut.Unlock() + atomic.AddInt64(&c.packetsOut, 1) + atomic.AddInt64(&c.bytesOut, dstHeaderLen) + return + + case pkt := <-c.in: + atomic.AddInt64(&c.packetsIn, 1) + atomic.AddInt64(&c.bytesIn, dstHeaderLen+int64(len(pkt.data))) + + c.expCount = 1 + + switch pkt.hdr.packetType { + case typeData: + c.rcvData(pkt) + case typeAck: + c.rcvAck(pkt) + case typeNegAck: + c.rcvNegAck(pkt) + case typeShutdown: + c.rcvShutdown(pkt) + default: + log.Println("Unhandled packet", pkt) + continue + } + + case <-c.exp.C: + c.eventExp() + c.resetExp() + + case n := <-c.debugResetRecvSeqNo: + // Back door for testing + c.lastAckedSeqNo = n - 1 + c.nextRecvSeqNo = n + } + } +} + +func (c *Conn) eventExp() { + c.expCount++ + + if c.sendBuffer.lost.Len() > 0 || c.sendBuffer.send.Len() > 0 { + c.cc.Exp() + c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate()) + c.sendBuffer.ScheduleResend() + + if debugConnection { + log.Println(c, "did resends due to Exp") + } + + if c.expCount > expCountClose && time.Since(c.expReset) > minTimeClose { + if debugConnection { + log.Println(c, "close due to Exp") + } + + // We're shutting down due to repeated exp:s. Don't wait for the + // send buffer to drain, which it would otherwise do in + // c.Close().. + c.sendBuffer.CrashStop() + + c.Close() + } + } +} + +func (c *Conn) rcvAck(pkt packet) { + ack := pkt.hdr.sequenceNo + + if debugConnection { + log.Printf("%v read Ack %v", c, ack) + } + + c.cc.Ack() + + if ack%rttMeasureSample == 0 { + if ts := timestamp(binary.BigEndian.Uint32(pkt.data)); ts > 0 { + if delay := time.Duration(timestampMicros()-ts) * time.Microsecond; delay > 0 { + c.packetDelaysMut.Lock() + c.packetDelays[c.packetDelaysSlot] = delay + c.packetDelaysSlot = (c.packetDelaysSlot + 1) % len(c.packetDelays) + c.packetDelaysMut.Unlock() + + if rtt, n := c.averageDelay(); n > 8 { + c.cc.UpdateRTT(rtt) + } + } + } + } + + c.sendBuffer.Acknowledge(ack) + c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate()) + + c.resetExp() +} + +func (c *Conn) averageDelay() (time.Duration, int) { + var total time.Duration + var n int + + c.packetDelaysMut.Lock() + for _, d := range c.packetDelays { + if d != 0 { + total += d + n++ + } + } + c.packetDelaysMut.Unlock() + + if n == 0 { + return 0, 0 + } + return total / time.Duration(n), n +} + +func (c *Conn) rcvNegAck(pkt packet) { + nak := pkt.hdr.sequenceNo + + if debugConnection { + log.Printf("%v read NegAck %v", c, nak) + } + + c.sendBuffer.NegativeAck(nak) + + //c.cc.NegAck() + c.resetExp() +} + +func (c *Conn) rcvShutdown(pkt packet) { + // XXX: We accept shutdown packets somewhat from the future since the + // sender will number the shutdown after any packets that might still be + // in the write buffer. This should be fixed to let the write buffer empty + // on close and reduce the window here. + if pkt.LessSeq(c.nextRecvSeqNo + 128) { + if debugConnection { + log.Println(c, "close due to shutdown") + } + c.Close() + } +} + +func (c *Conn) rcvData(pkt packet) { + if debugConnection { + log.Println(c, "recv data", pkt.hdr) + } + + if pkt.LessSeq(c.nextRecvSeqNo) { + if debugConnection { + log.Printf("%v old packet received; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo) + } + atomic.AddInt64(&c.droppedPackets, 1) + return + } + + if debugConnection { + log.Println(c, "into recv buffer:", pkt) + } + c.recvBuffer.InsertSorted(pkt) + if c.recvBuffer.LowestSeq() == c.nextRecvSeqNo { + for _, pkt := range c.recvBuffer.PopSequence(^sequenceNo(0)) { + if debugConnection { + log.Println(c, "from recv buffer:", pkt) + } + + // An in-sequence packet. + + c.nextRecvSeqNo = pkt.hdr.sequenceNo + 1 + + c.sendAck(pkt.hdr.timestamp) + + c.inbufMut.Lock() + for c.inbuf.Len() > len(pkt.data)+maxInputBuffer { + c.inbufCond.Wait() + select { + case <-c.closed: + return + default: + } + } + + c.inbuf.Write(pkt.data) + c.inbufCond.Broadcast() + c.inbufMut.Unlock() + } + } else { + if debugConnection { + log.Printf("%v lost; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo) + } + c.recvBuffer.InsertSorted(pkt) + c.sendNegAck() + atomic.AddInt64(&c.outOfOrderPackets, 1) + } +} + +func (c *Conn) sendAck(ts timestamp) { + if c.lastAckedSeqNo == c.nextRecvSeqNo { + return + } + + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], uint32(ts)) + c.mux.write(packet{ + src: c.connID, + dst: c.dst, + hdr: header{ + packetType: typeAck, + connID: c.remoteConnID, + sequenceNo: c.nextRecvSeqNo, + }, + data: buf[:], + }) + + atomic.AddInt64(&c.packetsOut, 1) + atomic.AddInt64(&c.bytesOut, dstHeaderLen) + if debugConnection { + log.Printf("%v send Ack %v", c, c.nextRecvSeqNo) + } + + c.lastAckedSeqNo = c.nextRecvSeqNo +} + +func (c *Conn) sendNegAck() { + if c.lastNegAckedSeqNo == c.nextRecvSeqNo { + return + } + + c.mux.write(packet{ + src: c.connID, + dst: c.dst, + hdr: header{ + packetType: typeNegAck, + connID: c.remoteConnID, + sequenceNo: c.nextRecvSeqNo, + }, + }) + + atomic.AddInt64(&c.packetsOut, 1) + atomic.AddInt64(&c.bytesOut, dstHeaderLen) + if debugConnection { + log.Printf("%v send NegAck %v", c, c.nextRecvSeqNo) + } + + c.lastNegAckedSeqNo = c.nextRecvSeqNo +} + +func (c *Conn) resetExp() { + d, _ := c.averageDelay() + d = d*4 + 10*time.Millisecond + + if d < defExpTime { + d = defExpTime + } + + c.expMut.Lock() + c.exp.Reset(d) + c.expMut.Unlock() +} + +// String returns a string representation of the connection. +func (c *Conn) String() string { + return fmt.Sprintf("%v/%v/%v", c.connID, c.LocalAddr(), c.RemoteAddr()) +} + +// Read reads data from the connection. +// Read can be made to time out and return a Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetReadDeadline. +func (c *Conn) Read(b []byte) (n int, err error) { + c.inbufMut.Lock() + defer c.inbufMut.Unlock() + for c.inbuf.Len() == 0 { + select { + case <-c.closed: + return 0, io.EOF + default: + } + c.inbufCond.Wait() + } + return c.inbuf.Read(b) +} + +// Write writes data to the connection. +// Write can be made to time out and return a Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetWriteDeadline. +func (c *Conn) Write(b []byte) (n int, err error) { + select { + case <-c.closed: + return 0, ErrClosedConn + default: + } + + sent := 0 + sliceSize := c.packetSize - sliceOverhead + for i := 0; i < len(b); i += sliceSize { + nxt := i + sliceSize + if nxt > len(b) { + nxt = len(b) + } + slice := b[i:nxt] + sliceCopy := c.mux.buffers.Get().([]byte)[:len(slice)] + copy(sliceCopy, slice) + + c.nextSeqNoMut.Lock() + pkt := packet{ + src: c.connID, + dst: c.dst, + hdr: header{ + packetType: typeData, + sequenceNo: c.nextSeqNo, + connID: c.remoteConnID, + }, + data: sliceCopy, + } + c.nextSeqNo++ + c.nextSeqNoMut.Unlock() + + if err := c.sendBuffer.Write(pkt); err != nil { + return sent, err + } + + atomic.AddInt64(&c.packetsOut, 1) + atomic.AddInt64(&c.bytesOut, int64(len(slice)+dstHeaderLen)) + + sent += len(slice) + c.resetExp() + } + return sent, nil +} + +// Close closes the connection. +// Any blocked Read or Write operations will be unblocked and return errors. +func (c *Conn) Close() error { + c.closeOnce.Do(func() { + if debugConnection { + log.Println(c, "explicit close start") + defer log.Println(c, "explicit close done") + } + + // XXX: Ugly hack to implement lingering sockets... + time.Sleep(4 * defExpTime) + + c.sendBuffer.Stop() + c.mux.removeConn(c) + close(c.closed) + + c.inbufMut.Lock() + c.inbufCond.Broadcast() + c.inbufMut.Unlock() + }) + return nil +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.mux.Addr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.dst +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. It is equivalent to calling both +// SetReadDeadline and SetWriteDeadline. +// +// A deadline is an absolute time after which I/O operations +// fail with a timeout (see type Error) instead of +// blocking. The deadline applies to all future I/O, not just +// the immediately following call to Read or Write. +// +// An idle timeout can be implemented by repeatedly extending +// the deadline after successful Read or Write calls. +// +// A zero value for t means I/O operations will not time out. +// +// BUG(jb): SetDeadline is not implemented. +func (c *Conn) SetDeadline(t time.Time) error { + return ErrNotImplemented +} + +// SetReadDeadline sets the deadline for future Read calls. +// A zero value for t means Read will not time out. +// +// BUG(jb): SetReadDeadline is not implemented. +func (c *Conn) SetReadDeadline(t time.Time) error { + return ErrNotImplemented +} + +// SetWriteDeadline sets the deadline for future Write calls. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means Write will not time out. +// +// BUG(jb): SetWriteDeadline is not implemented. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return ErrNotImplemented +} + +type Statistics struct { + DataPacketsIn int64 + DataPacketsOut int64 + DataBytesIn int64 + DataBytesOut int64 + ResentPackets int64 + DroppedPackets int64 + OutOfOrderPackets int64 +} + +// String returns a printable represetnation of the Statistics. +func (s Statistics) String() string { + return fmt.Sprintf("PktsIn: %d, PktsOut: %d, BytesIn: %d, BytesOut: %d, PktsResent: %d, PktsDropped: %d, PktsOutOfOrder: %d", + s.DataPacketsIn, s.DataPacketsOut, s.DataBytesIn, s.DataBytesOut, s.ResentPackets, s.DroppedPackets, s.OutOfOrderPackets) +} + +// GetStatistics returns a snapsht of the current connection statistics. +func (c *Conn) GetStatistics() Statistics { + return Statistics{ + DataPacketsIn: atomic.LoadInt64(&c.packetsIn), + DataPacketsOut: atomic.LoadInt64(&c.packetsOut), + DataBytesIn: atomic.LoadInt64(&c.bytesIn), + DataBytesOut: atomic.LoadInt64(&c.bytesOut), + ResentPackets: atomic.LoadInt64(&c.resentPackets), + DroppedPackets: atomic.LoadInt64(&c.droppedPackets), + OutOfOrderPackets: atomic.LoadInt64(&c.outOfOrderPackets), + } +} diff --git a/core/dst/cookie.go b/core/dst/cookie.go new file mode 100644 index 0000000..0daeb8b --- /dev/null +++ b/core/dst/cookie.go @@ -0,0 +1,29 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "net" +) + +var cookieKey = make([]byte, 16) + +func init() { + _, err := rand.Reader.Read(cookieKey) + if err != nil { + panic(err) + } +} + +func cookie(remote net.Addr) uint32 { + hash := sha256.New() + hash.Write([]byte(remote.String())) + hash.Write(cookieKey) + bs := hash.Sum(nil) + return binary.BigEndian.Uint32(bs) +} diff --git a/core/dst/debug.go b/core/dst/debug.go new file mode 100644 index 0000000..ed9f3fe --- /dev/null +++ b/core/dst/debug.go @@ -0,0 +1,26 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + "os" + "strings" +) + +var ( + debugConnection bool + debugMux bool + debugCC bool +) + +func init() { + debug := make(map[string]bool) + for _, s := range strings.Split(os.Getenv("DSTDEBUG"), ",") { + debug[strings.TrimSpace(s)] = true + } + debugConnection = debug["conn"] + debugMux = debug["mux"] + debugCC = debug["cc"] +} diff --git a/core/dst/doc.go b/core/dst/doc.go new file mode 100644 index 0000000..fe4c768 --- /dev/null +++ b/core/dst/doc.go @@ -0,0 +1,12 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +/* + +Package dst implements the Datagram Stream Transfer protocol. + +DST is a way to get reliable stream connections (like TCP) on top of UDP. + +*/ +package dst diff --git a/core/dst/errors.go b/core/dst/errors.go new file mode 100644 index 0000000..dd4eb0e --- /dev/null +++ b/core/dst/errors.go @@ -0,0 +1,23 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +// Error represents the various dst-internal error conditions. +type Error struct { + Err string +} + +// Error returns a string representation of the error. +func (e Error) Error() string { + return e.Err +} + +var ( + ErrClosedConn = &Error{"operation on closed connection"} + ErrClosedMux = &Error{"operation on closed mux"} + ErrHandshakeTimeout = &Error{"handshake timeout"} + ErrNotDST = &Error{"network is not dst"} + ErrNotImplemented = &Error{"not implemented"} +) diff --git a/core/dst/mux.go b/core/dst/mux.go new file mode 100644 index 0000000..8765843 --- /dev/null +++ b/core/dst/mux.go @@ -0,0 +1,422 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + "fmt" + + "net" + "sync" + "time" +) + +const ( + maxIncomingRequests = 1024 + maxPacketSize = 500 + handshakeTimeout = 5 * time.Second + handshakeInterval = 1 * time.Second +) + +// Mux is a UDP multiplexer of DST connections. +type Mux struct { + conn net.PacketConn + packetSize int + + conns map[connectionID]*Conn + handshakes map[connectionID]chan packet + connsMut sync.Mutex + + incoming chan *Conn + closed chan struct{} + closeOnce sync.Once + + buffers *sync.Pool +} + +// NewMux creates a new DST Mux on top of a packet connection. +func NewMux(conn net.PacketConn, packetSize int) *Mux { + if packetSize <= 0 { + packetSize = maxPacketSize + } + m := &Mux{ + conn: conn, + packetSize: packetSize, + conns: map[connectionID]*Conn{}, + handshakes: make(map[connectionID]chan packet), + incoming: make(chan *Conn, maxIncomingRequests), + closed: make(chan struct{}), + buffers: &sync.Pool{ + New: func() interface{} { + return make([]byte, packetSize) + }, + }, + } + + // Attempt to maximize buffer space. Start at 16 MB and work downwards 0.5 + // MB at a time. + + if conn, ok := conn.(*net.UDPConn); ok { + for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 { + err := conn.SetReadBuffer(buf) + if err == nil { + if debugMux { + log.Println(m, "read buffer is", buf) + } + break + } + } + for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 { + err := conn.SetWriteBuffer(buf) + if err == nil { + if debugMux { + log.Println(m, "write buffer is", buf) + } + break + } + } + } + + go m.readerLoop() + + return m +} + +// Accept waits for and returns the next connection to the listener. +func (m *Mux) Accept() (net.Conn, error) { + return m.AcceptDST() +} + +// AcceptDST waits for and returns the next connection to the listener. +func (m *Mux) AcceptDST() (*Conn, error) { + conn, ok := <-m.incoming + if !ok { + return nil, ErrClosedMux + } + return conn, nil +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (m *Mux) Close() error { + var err error = ErrClosedMux + m.closeOnce.Do(func() { + err = m.conn.Close() + close(m.incoming) + close(m.closed) + }) + return err +} + +// Addr returns the listener's network address. +func (m *Mux) Addr() net.Addr { + return m.conn.LocalAddr() +} + +// Dial connects to the address on the named network. +// +// Network must be "dst". +// +// Addresses have the form host:port. If host is a literal IPv6 address or +// host name, it must be enclosed in square brackets as in "[::1]:80", +// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and +// SplitHostPort manipulate addresses in this form. +// +// Examples: +// Dial("dst", "12.34.56.78:80") +// Dial("dst", "google.com:http") +// Dial("dst", "[2001:db8::1]:http") +// Dial("dst", "[fe80::1%lo0]:80") +func (m *Mux) Dial(network, addr string) (net.Conn, error) { + return m.DialDST(network, addr) +} + +// Dial connects to the address on the named network. +// +// Network must be "dst". +// +// Addresses have the form host:port. If host is a literal IPv6 address or +// host name, it must be enclosed in square brackets as in "[::1]:80", +// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and +// SplitHostPort manipulate addresses in this form. +// +// Examples: +// Dial("dst", "12.34.56.78:80") +// Dial("dst", "google.com:http") +// Dial("dst", "[2001:db8::1]:http") +// Dial("dst", "[fe80::1%lo0]:80") +func (m *Mux) DialDST(network, addr string) (*Conn, error) { + if network != "dst" { + return nil, ErrNotDST + } + + dst, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + resp := make(chan packet) + + m.connsMut.Lock() + connID := m.newConnID() + m.handshakes[connID] = resp + m.connsMut.Unlock() + + conn, err := m.clientHandshake(dst, connID, resp) + + m.connsMut.Lock() + defer m.connsMut.Unlock() + delete(m.handshakes, connID) + + if err != nil { + return nil, err + } + + m.conns[connID] = conn + return conn, nil +} + +// handshake performs the client side handshake (i.e. Dial) +func (m *Mux) clientHandshake(dst net.Addr, connID connectionID, resp chan packet) (*Conn, error) { + if debugMux { + log.Printf("%v dial %v connID %v", m, dst, connID) + } + + nextHandshake := time.NewTimer(0) + defer nextHandshake.Stop() + + handshakeTimeout := time.NewTimer(handshakeTimeout) + defer handshakeTimeout.Stop() + + var remoteCookie uint32 + seqNo := randomSeqNo() + + for { + select { + case <-m.closed: + // Failure. The mux has been closed. + return nil, ErrClosedConn + + case <-handshakeTimeout.C: + // Handshake timeout. Close and abort. + return nil, ErrHandshakeTimeout + + case <-nextHandshake.C: + // Send a handshake request. + + m.write(packet{ + src: connID, + dst: dst, + hdr: header{ + packetType: typeHandshake, + flags: flagRequest, + connID: 0, + sequenceNo: seqNo, + timestamp: timestampMicros(), + }, + data: handshakeData{uint32(m.packetSize), connID, remoteCookie}.marshal(), + }) + nextHandshake.Reset(handshakeInterval) + + case pkt := <-resp: + hd := unmarshalHandshakeData(pkt.data) + + if pkt.hdr.flags&flagCookie == flagCookie { + // We should resend the handshake request with a different cookie value. + remoteCookie = hd.cookie + nextHandshake.Reset(0) + } else if pkt.hdr.flags&flagResponse == flagResponse { + // Successfull handshake response. + conn := newConn(m, dst) + + conn.connID = connID + conn.remoteConnID = hd.connID + conn.nextRecvSeqNo = pkt.hdr.sequenceNo + 1 + conn.packetSize = int(hd.packetSize) + if conn.packetSize > m.packetSize { + conn.packetSize = m.packetSize + } + + conn.nextSeqNo = seqNo + 1 + + conn.start() + + return conn, nil + } + } + } +} + +func (m *Mux) readerLoop() { + buf := make([]byte, m.packetSize) + for { + buf = buf[:cap(buf)] + n, from, err := m.conn.ReadFrom(buf) + if err != nil { + m.Close() + return + } + buf = buf[:n] + + hdr := unmarshalHeader(buf) + + var bufCopy []byte + if len(buf) > dstHeaderLen { + bufCopy = m.buffers.Get().([]byte)[:len(buf)-dstHeaderLen] + copy(bufCopy, buf[dstHeaderLen:]) + } + + pkt := packet{hdr: hdr, data: bufCopy} + if debugMux { + log.Println(m, "read", pkt) + } + + if hdr.packetType == typeHandshake { + m.incomingHandshake(from, hdr, bufCopy) + } else { + m.connsMut.Lock() + conn, ok := m.conns[hdr.connID] + m.connsMut.Unlock() + + if ok { + conn.in <- packet{ + dst: nil, + hdr: hdr, + data: bufCopy, + } + } else if debugMux && hdr.packetType != typeShutdown { + log.Printf("packet %v for unknown conn %v", hdr, hdr.connID) + } + } + } +} + +func (m *Mux) incomingHandshake(from net.Addr, hdr header, data []byte) { + if hdr.connID == 0 { + // A new incoming handshake request. + m.incomingHandshakeRequest(from, hdr, data) + } else { + // A response to an ongoing handshake. + m.incomingHandshakeResponse(from, hdr, data) + } +} + +func (m *Mux) incomingHandshakeRequest(from net.Addr, hdr header, data []byte) { + if hdr.flags&flagRequest != flagRequest { + log.Printf("Handshake pattern with flags 0x%x to connID zero", hdr.flags) + return + } + + hd := unmarshalHandshakeData(data) + + correctCookie := cookie(from) + if hd.cookie != correctCookie { + // Incorrect or missing SYN cookie. Send back a handshake + // with the expected one. + m.write(packet{ + dst: from, + hdr: header{ + packetType: typeHandshake, + flags: flagResponse | flagCookie, + connID: hd.connID, + timestamp: timestampMicros(), + }, + data: handshakeData{ + packetSize: uint32(m.packetSize), + cookie: correctCookie, + }.marshal(), + }) + return + } + + seqNo := randomSeqNo() + + m.connsMut.Lock() + connID := m.newConnID() + + conn := newConn(m, from) + conn.connID = connID + conn.remoteConnID = hd.connID + conn.nextSeqNo = seqNo + 1 + conn.nextRecvSeqNo = hdr.sequenceNo + 1 + conn.packetSize = int(hd.packetSize) + if conn.packetSize > m.packetSize { + conn.packetSize = m.packetSize + } + conn.start() + + m.conns[connID] = conn + m.connsMut.Unlock() + + m.write(packet{ + dst: from, + hdr: header{ + packetType: typeHandshake, + flags: flagResponse, + connID: hd.connID, + sequenceNo: seqNo, + timestamp: timestampMicros(), + }, + data: handshakeData{ + connID: conn.connID, + packetSize: uint32(conn.packetSize), + }.marshal(), + }) + + m.incoming <- conn +} + +func (m *Mux) incomingHandshakeResponse(from net.Addr, hdr header, data []byte) { + m.connsMut.Lock() + handShake, ok := m.handshakes[hdr.connID] + m.connsMut.Unlock() + + if ok { + // This is a response to a handshake in progress. + handShake <- packet{ + dst: nil, + hdr: hdr, + data: data, + } + } else if debugMux && hdr.packetType != typeShutdown { + log.Printf("Handshake packet %v for unknown conn %v", hdr, hdr.connID) + } +} + +func (m *Mux) write(pkt packet) (int, error) { + buf := m.buffers.Get().([]byte) + buf = buf[:dstHeaderLen+len(pkt.data)] + pkt.hdr.marshal(buf) + copy(buf[dstHeaderLen:], pkt.data) + if debugMux { + log.Println(m, "write", pkt) + } + n, err := m.conn.WriteTo(buf, pkt.dst) + m.buffers.Put(buf) + return n, err +} + +func (m *Mux) String() string { + return fmt.Sprintf("Mux-%v", m.Addr()) +} + +// Find a unique connection ID +func (m *Mux) newConnID() connectionID { + for { + connID := randomConnID() + if _, ok := m.conns[connID]; ok { + continue + } + if _, ok := m.handshakes[connID]; ok { + continue + } + return connID + } +} + +func (m *Mux) removeConn(c *Conn) { + m.connsMut.Lock() + delete(m.conns, c.connID) + m.connsMut.Unlock() +} diff --git a/core/dst/packetlist.go b/core/dst/packetlist.go new file mode 100644 index 0000000..6e5ce9a --- /dev/null +++ b/core/dst/packetlist.go @@ -0,0 +1,119 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +type packetList struct { + packets []packet + slot int +} + +// CutLessSeq cuts packets from the start of the list with sequence numbers +// lower than seq. Returns the number of packets that were cut. +func (l *packetList) CutLessSeq(seq sequenceNo) int { + var i, cut int + for i = range l.packets { + if i == l.slot { + break + } + if !l.packets[i].LessSeq(seq) { + break + } + cut++ + } + if cut > 0 { + l.Cut(cut) + } + return cut +} + +func (l *packetList) Cut(n int) { + copy(l.packets, l.packets[n:]) + l.slot -= n +} + +func (l *packetList) Full() bool { + return l.slot == len(l.packets) +} + +func (l *packetList) All() []packet { + return l.packets[:l.slot] +} + +func (l *packetList) Append(pkt packet) bool { + if l.slot == len(l.packets) { + return false + } + l.packets[l.slot] = pkt + l.slot++ + return true +} + +func (l *packetList) AppendAll(pkts []packet) { + l.packets = append(l.packets[:l.slot], pkts...) + l.slot += len(pkts) +} + +func (l *packetList) Cap() int { + return len(l.packets) +} + +func (l *packetList) Len() int { + return l.slot +} + +func (l *packetList) Resize(s int) { + if s <= cap(l.packets) { + l.packets = l.packets[:s] + } else { + t := make([]packet, s) + copy(t, l.packets) + l.packets = t + } +} + +func (l *packetList) InsertSorted(pkt packet) { + for i := range l.packets { + if i >= l.slot { + l.packets[i] = pkt + l.slot++ + return + } + if pkt.hdr.sequenceNo == l.packets[i].hdr.sequenceNo { + return + } + if pkt.Less(l.packets[i]) { + copy(l.packets[i+1:], l.packets[i:]) + l.packets[i] = pkt + if l.slot < len(l.packets) { + l.slot++ + } + return + } + } +} + +func (l *packetList) LowestSeq() sequenceNo { + return l.packets[0].hdr.sequenceNo +} + +func (l *packetList) PopSequence(maxSeq sequenceNo) []packet { + highSeq := l.packets[0].hdr.sequenceNo + if highSeq >= maxSeq { + return nil + } + + var i int + for i = 1; i < l.slot; i++ { + seq := l.packets[i].hdr.sequenceNo + if seq != highSeq+1 || seq >= maxSeq { + break + } + highSeq++ + } + pkts := make([]packet, i) + copy(pkts, l.packets[:i]) + l.Cut(i) + return pkts +} diff --git a/core/dst/packets.go b/core/dst/packets.go new file mode 100644 index 0000000..c44b8cd --- /dev/null +++ b/core/dst/packets.go @@ -0,0 +1,155 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + "encoding/binary" + "fmt" + "net" +) + +const dstHeaderLen = 12 + +type packetType int8 + +const ( + typeHandshake packetType = 0x0 + typeData = 0x1 + typeAck = 0x2 + typeNegAck = 0x3 + typeShutdown = 0x4 +) + +func (t packetType) String() string { + switch t { + case typeData: + return "data" + case typeHandshake: + return "handshake" + case typeAck: + return "ack" + case typeNegAck: + return "negAck" + case typeShutdown: + return "shutdown" + default: + return "unknown" + } +} + +type connectionID uint32 + +func (c connectionID) String() string { + return fmt.Sprintf("Ci%08x", uint32(c)) +} + +type sequenceNo uint32 + +func (s sequenceNo) String() string { + return fmt.Sprintf("Sq%d", uint32(s)) +} + +type timestamp uint32 + +func (t timestamp) String() string { + return fmt.Sprintf("Ts%d", uint32(t)) +} + +const ( + flagRequest = 1 << 0 // This packet is a handshake request + flagResponse = 1 << 1 // This packet is a handshake response + flagCookie = 1 << 2 // This packet contains a coookie challenge +) + +type header struct { + packetType packetType // 4 bits + flags uint8 // 4 bits + connID connectionID // 24 bits + sequenceNo sequenceNo + timestamp timestamp +} + +func (h header) marshal(bs []byte) { + binary.BigEndian.PutUint32(bs, uint32(h.connID&0xffffff)) + bs[0] = h.flags | uint8(h.packetType)<<4 + binary.BigEndian.PutUint32(bs[4:], uint32(h.sequenceNo)) + binary.BigEndian.PutUint32(bs[8:], uint32(h.timestamp)) +} + +func unmarshalHeader(bs []byte) header { + var h header + h.packetType = packetType(bs[0] >> 4) + h.flags = bs[0] & 0xf + h.connID = connectionID(binary.BigEndian.Uint32(bs) & 0xffffff) + h.sequenceNo = sequenceNo(binary.BigEndian.Uint32(bs[4:])) + h.timestamp = timestamp(binary.BigEndian.Uint32(bs[8:])) + return h +} + +func (h header) String() string { + return fmt.Sprintf("header{type=%s flags=0x%x connID=%v seq=%v time=%v}", h.packetType, h.flags, h.connID, h.sequenceNo, h.timestamp) +} + +type handshakeData struct { + packetSize uint32 + connID connectionID + cookie uint32 +} + +func (h handshakeData) marshalInto(data []byte) { + binary.BigEndian.PutUint32(data[0:], h.packetSize) + binary.BigEndian.PutUint32(data[4:], uint32(h.connID)) + binary.BigEndian.PutUint32(data[8:], h.cookie) +} + +func (h handshakeData) marshal() []byte { + var data [12]byte + h.marshalInto(data[:]) + return data[:] +} + +func unmarshalHandshakeData(data []byte) handshakeData { + var h handshakeData + h.packetSize = binary.BigEndian.Uint32(data[0:]) + h.connID = connectionID(binary.BigEndian.Uint32(data[4:])) + h.cookie = binary.BigEndian.Uint32(data[8:]) + return h +} + +func (h handshakeData) String() string { + return fmt.Sprintf("handshake{size=%d connID=%v cookie=0x%08x}", h.packetSize, h.connID, h.cookie) +} + +type packet struct { + src connectionID + dst net.Addr + hdr header + data []byte +} + +func (p packet) String() string { + var dst string + if p.dst != nil { + dst = "dst=" + p.dst.String() + " " + } + switch p.hdr.packetType { + case typeHandshake: + return fmt.Sprintf("%spacket{src=%v %v %v}", dst, p.src, p.hdr, unmarshalHandshakeData(p.data)) + default: + return fmt.Sprintf("%spacket{src=%v %v data[:%d]}", dst, p.src, p.hdr, len(p.data)) + } +} + +func (p packet) LessSeq(seq sequenceNo) bool { + diff := seq - p.hdr.sequenceNo + if diff == 0 { + return false + } + return diff < 1<<31 +} + +func (a packet) Less(b packet) bool { + return a.LessSeq(b.hdr.sequenceNo) +} diff --git a/core/dst/sendbuffer.go b/core/dst/sendbuffer.go new file mode 100644 index 0000000..786b15d --- /dev/null +++ b/core/dst/sendbuffer.go @@ -0,0 +1,260 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + "fmt" + + "sync" + + "github.com/juju/ratelimit" +) + +/* + sendWindow + v + [S|S|S|S|Q|Q|Q|Q| | | | | | | | | ] + ^ ^writeSlot + sendSlot +*/ +type sendBuffer struct { + mux *Mux // we send packets here + scheduler *ratelimit.Bucket // sets send rate for packets + + sendWindow int // maximum number of outstanding non-acked packets + packetRate int // target pps + + send packetList // buffered packets + sendSlot int // buffer slot from which to send next packet + + lost packetList // list of packets reported lost by timeout + lostSlot int // next lost packet to resend + + closed bool + closing bool + mut sync.Mutex + cond *sync.Cond +} + +const ( + schedulerRate = 1e6 + schedulerCapacity = schedulerRate / 40 +) + +// newSendBuffer creates a new send buffer with a zero window. +// SetRateAndWindow() must be called to set an initial packet rate and send +// window before using. +func newSendBuffer(m *Mux) *sendBuffer { + b := &sendBuffer{ + mux: m, + scheduler: ratelimit.NewBucketWithRate(schedulerRate, schedulerCapacity), + } + b.cond = sync.NewCond(&b.mut) + go b.writerLoop() + return b +} + +// Write puts a new packet in send buffer and schedules a send. Blocks when +// the window size is or would be exceeded. +func (b *sendBuffer) Write(pkt packet) error { + b.mut.Lock() + defer b.mut.Unlock() + + for b.send.Full() || b.send.Len() >= b.sendWindow { + if b.closing { + return ErrClosedConn + } + if debugConnection { + log.Println(b, "Write blocked") + } + b.cond.Wait() + } + if !b.send.Append(pkt) { + panic("bug: append failed") + } + b.cond.Broadcast() + return nil +} + +// Acknowledge removes packets with lower sequence numbers from the loss list +// or send buffer. +func (b *sendBuffer) Acknowledge(seq sequenceNo) { + b.mut.Lock() + + if cut := b.lost.CutLessSeq(seq); cut > 0 { + if debugConnection { + log.Println(b, "cut", cut, "from loss list") + } + // Next resend should always start with the first packet, regardless + // of what we might already have resent previously. + b.lostSlot = 0 + b.cond.Broadcast() + } + + if cut := b.send.CutLessSeq(seq); cut > 0 { + if debugConnection { + log.Println(b, "cut", cut, "from send list") + } + b.sendSlot -= cut + b.cond.Broadcast() + } + + b.mut.Unlock() +} + +func (b *sendBuffer) NegativeAck(seq sequenceNo) { + b.mut.Lock() + + pkts := b.send.PopSequence(seq) + if cut := len(pkts); cut > 0 { + b.lost.AppendAll(pkts) + if debugConnection { + log.Println(b, "cut", cut, "from send list, adding to loss list") + log.Println(seq, pkts) + } + b.sendSlot -= cut + b.lostSlot = 0 + b.cond.Broadcast() + } + + b.mut.Unlock() +} + +// ScheduleResend arranges for a resend of all currently unacknowledged +// packets. +func (b *sendBuffer) ScheduleResend() { + b.mut.Lock() + + if b.sendSlot > 0 { + // There are packets that have been sent but not acked. Move them from + // the send buffer to the loss list for retransmission. + if debugConnection { + log.Println(b, "scheduled resend from send list", b.sendSlot) + } + + // Append the packets to the loss list and rewind the send buffer + b.lost.AppendAll(b.send.All()[:b.sendSlot]) + b.send.Cut(b.sendSlot) + b.sendSlot = 0 + b.cond.Broadcast() + } + + if b.lostSlot > 0 { + // Also resend whatever was already in the loss list + if debugConnection { + log.Println(b, "scheduled resend from loss list", b.lostSlot) + } + b.lostSlot = 0 + b.cond.Broadcast() + } + + b.mut.Unlock() +} + +// SetWindowAndRate sets the window size (in packets) and packet rate (in +// packets per second) to use when sending. +func (b *sendBuffer) SetWindowAndRate(sendWindow, packetRate int) { + b.mut.Lock() + if debugConnection { + log.Println(b, "new window & rate", sendWindow, packetRate) + } + b.packetRate = packetRate + b.sendWindow = sendWindow + if b.sendWindow > b.send.Cap() { + b.send.Resize(b.sendWindow) + b.cond.Broadcast() + } + b.mut.Unlock() +} + +// Stop stops the send buffer from any doing further sending, but waits for +// the current buffers to be drained. +func (b *sendBuffer) Stop() { + b.mut.Lock() + + if b.closed || b.closing { + return + } + + b.closing = true + for b.lost.Len() > 0 || b.send.Len() > 0 { + b.cond.Wait() + } + + b.closed = true + b.cond.Broadcast() + b.mut.Unlock() +} + +// CrashStop stops the send buffer from any doing further sending, without +// waiting for buffers to drain. +func (b *sendBuffer) CrashStop() { + b.mut.Lock() + + if b.closed || b.closing { + return + } + + b.closing = true + b.closed = true + b.cond.Broadcast() + b.mut.Unlock() +} + +func (b *sendBuffer) String() string { + return fmt.Sprintf("sendBuffer@%p", b) +} + +func (b *sendBuffer) writerLoop() { + if debugConnection { + log.Println(b, "writer() starting") + defer log.Println(b, "writer() exiting") + } + + b.scheduler.Take(schedulerCapacity) + for { + var pkt packet + b.mut.Lock() + for b.lostSlot >= b.sendWindow || + (b.sendSlot == b.send.Len() && b.lostSlot == b.lost.Len()) { + if b.closed { + b.mut.Unlock() + return + } + + if debugConnection { + log.Println(b, "writer() paused", b.lostSlot, b.sendSlot, b.sendWindow, b.lost.Len()) + } + b.cond.Wait() + } + + if b.lostSlot < b.lost.Len() { + pkt = b.lost.All()[b.lostSlot] + pkt.hdr.timestamp = timestampMicros() + b.lostSlot++ + + if debugConnection { + log.Println(b, "resend", b.lostSlot, b.lost.Len(), b.sendWindow, pkt.hdr.connID, pkt.hdr.sequenceNo) + } + } else if b.sendSlot < b.send.Len() { + pkt = b.send.All()[b.sendSlot] + pkt.hdr.timestamp = timestampMicros() + b.sendSlot++ + + if debugConnection { + log.Println(b, "send", b.sendSlot, b.send.Len(), b.sendWindow, pkt.hdr.connID, pkt.hdr.sequenceNo) + } + } + + b.cond.Broadcast() + packetRate := b.packetRate + b.mut.Unlock() + + if pkt.dst != nil { + b.scheduler.Wait(schedulerRate / int64(packetRate)) + b.mux.write(pkt) + } + } +} diff --git a/core/dst/util.go b/core/dst/util.go new file mode 100644 index 0000000..cdd6d91 --- /dev/null +++ b/core/dst/util.go @@ -0,0 +1,29 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + logger "log" + "math/rand" + "os" + "time" +) + +var log = logger.New(os.Stderr, "", logger.LstdFlags) + +func SetLogger(l *logger.Logger) { + log = l +} +func timestampMicros() timestamp { + return timestamp(time.Now().UnixNano() / 1000) +} + +func randomSeqNo() sequenceNo { + return sequenceNo(rand.Uint32()) +} + +func randomConnID() connectionID { + return connectionID(rand.Uint32() & 0xffffff) +} diff --git a/core/dst/windowcc.go b/core/dst/windowcc.go new file mode 100644 index 0000000..cc2dd69 --- /dev/null +++ b/core/dst/windowcc.go @@ -0,0 +1,144 @@ +// Copyright 2014 The DST Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package dst + +import ( + "fmt" + "io" + "os" + "time" +) + +type windowCC struct { + minWindow int + maxWindow int + currentWindow int + minRate int + maxRate int + currentRate int + targetRate int + + curRTT time.Duration + minRTT time.Duration + + statsFile io.WriteCloser + start time.Time +} + +func newWindowCC() *windowCC { + var statsFile io.WriteCloser + + if debugCC { + statsFile, _ = os.Create(fmt.Sprintf("cc-log-%d.csv", time.Now().Unix())) + fmt.Fprintf(statsFile, "ms,minWin,maxWin,curWin,minRate,maxRate,curRate,minRTT,curRTT\n") + } + + return &windowCC{ + minWindow: 1, // Packets + maxWindow: 16 << 10, + currentWindow: 1, + + minRate: 100, // PPS + maxRate: 80e3, // Roughly 1 Gbps at 1500 bytes per packet + currentRate: 100, + targetRate: 1000, + + minRTT: 10 * time.Second, + statsFile: statsFile, + start: time.Now(), + } +} + +func (w *windowCC) Ack() { + if w.curRTT > w.minRTT+100*time.Millisecond { + return + } + + changed := false + + if w.currentWindow < w.maxWindow { + w.currentWindow++ + changed = true + } + + if w.currentRate != w.targetRate { + w.currentRate = (w.currentRate*7 + w.targetRate) / 8 + changed = true + } + + if changed && debugCC { + w.log() + log.Println("Ack", w.currentWindow, w.currentRate) + } +} + +func (w *windowCC) NegAck() { + if w.currentWindow > w.minWindow { + w.currentWindow /= 2 + } + if w.currentRate > w.minRate { + w.currentRate /= 2 + } + if debugCC { + w.log() + log.Println("NegAck", w.currentWindow, w.currentRate) + } +} + +func (w *windowCC) Exp() { + w.currentWindow = w.minWindow + if debugCC { + w.log() + log.Println("Exp", w.currentWindow, w.currentRate) + } +} + +func (w *windowCC) SendWindow() int { + if w.currentWindow < w.minWindow { + return w.minWindow + } + if w.currentWindow > w.maxWindow { + return w.maxWindow + } + return w.currentWindow +} + +func (w *windowCC) PacketRate() int { + if w.currentRate < w.minRate { + return w.minRate + } + if w.currentRate > w.maxRate { + return w.maxRate + } + return w.currentRate +} + +func (w *windowCC) UpdateRTT(rtt time.Duration) { + w.curRTT = rtt + if w.curRTT < w.minRTT { + w.minRTT = w.curRTT + if debugCC { + log.Println("Min RTT", w.minRTT) + } + } + + if w.curRTT > w.minRTT+200*time.Millisecond && w.targetRate > 2*w.minRate { + w.targetRate -= w.minRate + } else if w.curRTT < w.minRTT+20*time.Millisecond && w.targetRate < w.maxRate { + w.targetRate += w.minRate + } + + if debugCC { + w.log() + log.Println("RTT", w.curRTT, "target rate", w.targetRate, "current rate", w.currentRate, "current window", w.currentWindow) + } +} + +func (w *windowCC) log() { + if w.statsFile == nil { + return + } + fmt.Fprintf(w.statsFile, "%.02f,%d,%d,%d,%d,%d,%d,%.02f,%.02f\n", time.Since(w.start).Seconds()*1000, w.minWindow, w.maxWindow, w.currentWindow, w.minRate, w.maxRate, w.currentRate, w.minRTT.Seconds()*1000, w.curRTT.Seconds()*1000) +} diff --git a/core/lib/buf/leakybuf.go b/core/lib/buf/leakybuf.go new file mode 100644 index 0000000..ee56728 --- /dev/null +++ b/core/lib/buf/leakybuf.go @@ -0,0 +1,52 @@ +// Provides leaky buffer, based on the example in Effective Go. +package buf + +type LeakyBuf struct { + bufSize int // size of each buffer + freeList chan []byte +} + +const LeakyBufSize = 2048 // data.len(2) + hmacsha1(10) + data(4096) +const maxNBuf = 2048 + +var LeakyBuffer = NewLeakyBuf(maxNBuf, LeakyBufSize) + +func Get() (b []byte) { + return LeakyBuffer.Get() +} +func Put(b []byte) { + LeakyBuffer.Put(b) +} + +// NewLeakyBuf creates a leaky buffer which can hold at most n buffer, each +// with bufSize bytes. +func NewLeakyBuf(n, bufSize int) *LeakyBuf { + return &LeakyBuf{ + bufSize: bufSize, + freeList: make(chan []byte, n), + } +} + +// Get returns a buffer from the leaky buffer or create a new buffer. +func (lb *LeakyBuf) Get() (b []byte) { + select { + case b = <-lb.freeList: + default: + b = make([]byte, lb.bufSize) + } + return +} + +// Put add the buffer into the free buffer pool for reuse. Panic if the buffer +// size is not the same with the leaky buffer's. This is intended to expose +// error usage of leaky buffer. +func (lb *LeakyBuf) Put(b []byte) { + if len(b) != lb.bufSize { + panic("invalid buffer size that's put into leaky buffer") + } + select { + case lb.freeList <- b: + default: + } + return +} diff --git a/core/lib/ioutils/utils.go b/core/lib/ioutils/utils.go new file mode 100644 index 0000000..70560cb --- /dev/null +++ b/core/lib/ioutils/utils.go @@ -0,0 +1,68 @@ +package ioutils + +import ( + "io" + logger "log" + + lbuf "github.com/snail007/goproxy/core/lib/buf" +) + +func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("bind crashed %s", err) + } + }() + e1 := make(chan interface{}, 1) + e2 := make(chan interface{}, 1) + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("bind crashed %s", err) + } + }() + //_, err := io.Copy(dst, src) + err := ioCopy(dst, src) + e1 <- err + }() + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("bind crashed %s", err) + } + }() + //_, err := io.Copy(src, dst) + err := ioCopy(src, dst) + e2 <- err + }() + var err interface{} + select { + case err = <-e1: + //log.Printf("e1") + case err = <-e2: + //log.Printf("e2") + } + src.Close() + dst.Close() + if fn != nil { + fn(err) + } + }() +} +func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) { + buf := lbuf.LeakyBuffer.Get() + defer lbuf.LeakyBuffer.Put(buf) + n := 0 + for { + n, err = src.Read(buf) + if n > 0 { + if _, e := dst.Write(buf[0:n]); e != nil { + return e + } + } + if err != nil { + return + } + } +} diff --git a/core/lib/kcpcfg/args.go b/core/lib/kcpcfg/args.go new file mode 100644 index 0000000..5d3b67c --- /dev/null +++ b/core/lib/kcpcfg/args.go @@ -0,0 +1,24 @@ +package kcpcfg + +import kcp "github.com/xtaci/kcp-go" + +type KCPConfigArgs struct { + Key *string + Crypt *string + Mode *string + MTU *int + SndWnd *int + RcvWnd *int + DataShard *int + ParityShard *int + DSCP *int + NoComp *bool + AckNodelay *bool + NoDelay *int + Interval *int + Resend *int + NoCongestion *int + SockBuf *int + KeepAlive *int + Block kcp.BlockCrypt +} diff --git a/core/lib/mapx/map.go b/core/lib/mapx/map.go new file mode 100644 index 0000000..3c10df9 --- /dev/null +++ b/core/lib/mapx/map.go @@ -0,0 +1,315 @@ +package mapx + +import ( + "encoding/json" + "sync" +) + +var SHARD_COUNT = 32 + +// A "thread" safe map of type string:Anything. +// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. +type ConcurrentMap []*ConcurrentMapShared + +// A "thread" safe string to anything map. +type ConcurrentMapShared struct { + items map[string]interface{} + sync.RWMutex // Read Write mutex, guards access to internal map. +} + +// Creates a new concurrent map. +func NewConcurrentMap() ConcurrentMap { + m := make(ConcurrentMap, SHARD_COUNT) + for i := 0; i < SHARD_COUNT; i++ { + m[i] = &ConcurrentMapShared{items: make(map[string]interface{})} + } + return m +} + +// Returns shard under given key +func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared { + return m[uint(fnv32(key))%uint(SHARD_COUNT)] +} + +func (m ConcurrentMap) MSet(data map[string]interface{}) { + for key, value := range data { + shard := m.GetShard(key) + shard.Lock() + shard.items[key] = value + shard.Unlock() + } +} + +// Sets the given value under the specified key. +func (m ConcurrentMap) Set(key string, value interface{}) { + // Get map shard. + shard := m.GetShard(key) + shard.Lock() + shard.items[key] = value + shard.Unlock() +} + +// Callback to return new element to be inserted into the map +// It is called while lock is held, therefore it MUST NOT +// try to access other keys in same map, as it can lead to deadlock since +// Go sync.RWLock is not reentrant +type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{} + +// Insert or Update - updates existing element or inserts a new one using UpsertCb +func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) { + shard := m.GetShard(key) + shard.Lock() + v, ok := shard.items[key] + res = cb(ok, v, value) + shard.items[key] = res + shard.Unlock() + return res +} + +// Sets the given value under the specified key if no value was associated with it. +func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool { + // Get map shard. + shard := m.GetShard(key) + shard.Lock() + _, ok := shard.items[key] + if !ok { + shard.items[key] = value + } + shard.Unlock() + return !ok +} + +// Retrieves an element from map under given key. +func (m ConcurrentMap) Get(key string) (interface{}, bool) { + // Get shard + shard := m.GetShard(key) + shard.RLock() + // Get item from shard. + val, ok := shard.items[key] + shard.RUnlock() + return val, ok +} + +// Returns the number of elements within the map. +func (m ConcurrentMap) Count() int { + count := 0 + for i := 0; i < SHARD_COUNT; i++ { + shard := m[i] + shard.RLock() + count += len(shard.items) + shard.RUnlock() + } + return count +} + +// Looks up an item under specified key +func (m ConcurrentMap) Has(key string) bool { + // Get shard + shard := m.GetShard(key) + shard.RLock() + // See if element is within shard. + _, ok := shard.items[key] + shard.RUnlock() + return ok +} + +// Removes an element from the map. +func (m ConcurrentMap) Remove(key string) { + // Try to get shard. + shard := m.GetShard(key) + shard.Lock() + delete(shard.items, key) + shard.Unlock() +} + +// Removes an element from the map and returns it +func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) { + // Try to get shard. + shard := m.GetShard(key) + shard.Lock() + v, exists = shard.items[key] + delete(shard.items, key) + shard.Unlock() + return v, exists +} + +// Checks if map is empty. +func (m ConcurrentMap) IsEmpty() bool { + return m.Count() == 0 +} + +// Used by the Iter & IterBuffered functions to wrap two variables together over a channel, +type Tuple struct { + Key string + Val interface{} +} + +// Returns an iterator which could be used in a for range loop. +// +// Deprecated: using IterBuffered() will get a better performence +func (m ConcurrentMap) Iter() <-chan Tuple { + chans := snapshot(m) + ch := make(chan Tuple) + go fanIn(chans, ch) + return ch +} + +// Returns a buffered iterator which could be used in a for range loop. +func (m ConcurrentMap) IterBuffered() <-chan Tuple { + chans := snapshot(m) + total := 0 + for _, c := range chans { + total += cap(c) + } + ch := make(chan Tuple, total) + go fanIn(chans, ch) + return ch +} + +// Returns a array of channels that contains elements in each shard, +// which likely takes a snapshot of `m`. +// It returns once the size of each buffered channel is determined, +// before all the channels are populated using goroutines. +func snapshot(m ConcurrentMap) (chans []chan Tuple) { + chans = make([]chan Tuple, SHARD_COUNT) + wg := sync.WaitGroup{} + wg.Add(SHARD_COUNT) + // Foreach shard. + for index, shard := range m { + go func(index int, shard *ConcurrentMapShared) { + // Foreach key, value pair. + shard.RLock() + chans[index] = make(chan Tuple, len(shard.items)) + wg.Done() + for key, val := range shard.items { + chans[index] <- Tuple{key, val} + } + shard.RUnlock() + close(chans[index]) + }(index, shard) + } + wg.Wait() + return chans +} + +// fanIn reads elements from channels `chans` into channel `out` +func fanIn(chans []chan Tuple, out chan Tuple) { + wg := sync.WaitGroup{} + wg.Add(len(chans)) + for _, ch := range chans { + go func(ch chan Tuple) { + for t := range ch { + out <- t + } + wg.Done() + }(ch) + } + wg.Wait() + close(out) +} + +// Returns all items as map[string]interface{} +func (m ConcurrentMap) Items() map[string]interface{} { + tmp := make(map[string]interface{}) + + // Insert items to temporary map. + for item := range m.IterBuffered() { + tmp[item.Key] = item.Val + } + + return tmp +} + +// Iterator callback,called for every key,value found in +// maps. RLock is held for all calls for a given shard +// therefore callback sess consistent view of a shard, +// but not across the shards +type IterCb func(key string, v interface{}) + +// Callback based iterator, cheapest way to read +// all elements in a map. +func (m ConcurrentMap) IterCb(fn IterCb) { + for idx := range m { + shard := (m)[idx] + shard.RLock() + for key, value := range shard.items { + fn(key, value) + } + shard.RUnlock() + } +} + +// Return all keys as []string +func (m ConcurrentMap) Keys() []string { + count := m.Count() + ch := make(chan string, count) + go func() { + // Foreach shard. + wg := sync.WaitGroup{} + wg.Add(SHARD_COUNT) + for _, shard := range m { + go func(shard *ConcurrentMapShared) { + // Foreach key, value pair. + shard.RLock() + for key := range shard.items { + ch <- key + } + shard.RUnlock() + wg.Done() + }(shard) + } + wg.Wait() + close(ch) + }() + + // Generate keys + keys := make([]string, 0, count) + for k := range ch { + keys = append(keys, k) + } + return keys +} + +//Reviles ConcurrentMap "private" variables to json marshal. +func (m ConcurrentMap) MarshalJSON() ([]byte, error) { + // Create a temporary map, which will hold all item spread across shards. + tmp := make(map[string]interface{}) + + // Insert items to temporary map. + for item := range m.IterBuffered() { + tmp[item.Key] = item.Val + } + return json.Marshal(tmp) +} + +func fnv32(key string) uint32 { + hash := uint32(2166136261) + const prime32 = uint32(16777619) + for i := 0; i < len(key); i++ { + hash *= prime32 + hash ^= uint32(key[i]) + } + return hash +} + +// Concurrent map uses Interface{} as its value, therefor JSON Unmarshal +// will probably won't know which to type to unmarshal into, in such case +// we'll end up with a value of type map[string]interface{}, In most cases this isn't +// out value type, this is why we've decided to remove this functionality. + +// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) { +// // Reverse process of Marshal. + +// tmp := make(map[string]interface{}) + +// // Unmarshal into a single map. +// if err := json.Unmarshal(b, &tmp); err != nil { +// return nil +// } + +// // foreach key,value pair in temporary map insert into our concurrent map. +// for key, val := range tmp { +// m.Set(key, val) +// } +// return nil +// } diff --git a/core/lib/socks5/socks5.go b/core/lib/socks5/socks5.go new file mode 100644 index 0000000..e67b5ff --- /dev/null +++ b/core/lib/socks5/socks5.go @@ -0,0 +1,159 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "net" + "strconv" +) + +const ( + Method_NO_AUTH = uint8(0x00) + Method_GSSAPI = uint8(0x01) + Method_USER_PASS = uint8(0x02) + Method_IANA = uint8(0x7F) + Method_RESVERVE = uint8(0x80) + Method_NONE_ACCEPTABLE = uint8(0xFF) + VERSION_V5 = uint8(0x05) + CMD_CONNECT = uint8(0x01) + CMD_BIND = uint8(0x02) + CMD_ASSOCIATE = uint8(0x03) + ATYP_IPV4 = uint8(0x01) + ATYP_DOMAIN = uint8(0x03) + ATYP_IPV6 = uint8(0x04) + REP_SUCCESS = uint8(0x00) + REP_REQ_FAIL = uint8(0x01) + REP_RULE_FORBIDDEN = uint8(0x02) + REP_NETWOR_UNREACHABLE = uint8(0x03) + REP_HOST_UNREACHABLE = uint8(0x04) + REP_CONNECTION_REFUSED = uint8(0x05) + REP_TTL_TIMEOUT = uint8(0x06) + REP_CMD_UNSUPPORTED = uint8(0x07) + REP_ATYP_UNSUPPORTED = uint8(0x08) + REP_UNKNOWN = uint8(0x09) + RSV = uint8(0x00) +) + +var ( + ZERO_IP = []byte{0x00, 0x00, 0x00, 0x00} + ZERO_PORT = []byte{0x00, 0x00} +) +var Socks5Errors = []string{ + "", + "general failure", + "connection forbidden", + "network unreachable", + "host unreachable", + "connection refused", + "TTL expired", + "command not supported", + "address type not supported", +} + +// Auth contains authentication parameters that specific Dialers may require. +type UsernamePassword struct { + Username, Password string +} + +type PacketUDP struct { + rsv uint16 + frag uint8 + atype uint8 + dstHost string + dstPort string + data []byte +} + +func NewPacketUDP() (p PacketUDP) { + return PacketUDP{} +} +func (p *PacketUDP) Build(destAddr string, data []byte) (err error) { + host, port, err := net.SplitHostPort(destAddr) + if err != nil { + return + } + p.rsv = 0 + p.frag = 0 + p.dstHost = host + p.dstPort = port + p.atype = ATYP_IPV4 + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + p.atype = ATYP_IPV4 + ip = ip4 + } else { + p.atype = ATYP_IPV6 + } + } else { + if len(host) > 255 { + err = errors.New("proxy: destination host name too long: " + host) + return + } + p.atype = ATYP_DOMAIN + } + p.data = data + + return +} +func (p *PacketUDP) Parse(b []byte) (err error) { + p.frag = uint8(b[2]) + if p.frag != 0 { + err = fmt.Errorf("FRAG only support for 0 , %v ,%v", p.frag, b[:4]) + return + } + portIndex := 0 + p.atype = b[3] + switch p.atype { + case ATYP_IPV4: //IP V4 + p.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String() + portIndex = 8 + case ATYP_DOMAIN: //域名 + domainLen := uint8(b[4]) + p.dstHost = string(b[5 : 5+domainLen]) //b[4]表示域名的长度 + portIndex = int(5 + domainLen) + case ATYP_IPV6: //IP V6 + p.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String() + portIndex = 20 + } + p.dstPort = strconv.Itoa(int(b[portIndex])<<8 | int(b[portIndex+1])) + p.data = b[portIndex+2:] + return +} +func (p *PacketUDP) Header() []byte { + header := new(bytes.Buffer) + header.Write([]byte{0x00, 0x00, p.frag, p.atype}) + if p.atype == ATYP_IPV4 { + ip := net.ParseIP(p.dstHost) + header.Write(ip.To4()) + } else if p.atype == ATYP_IPV6 { + ip := net.ParseIP(p.dstHost) + header.Write(ip.To16()) + } else if p.atype == ATYP_DOMAIN { + hBytes := []byte(p.dstHost) + header.WriteByte(byte(len(hBytes))) + header.Write(hBytes) + } + port, _ := strconv.ParseUint(p.dstPort, 10, 64) + portBytes := new(bytes.Buffer) + binary.Write(portBytes, binary.BigEndian, port) + header.Write(portBytes.Bytes()[portBytes.Len()-2:]) + return header.Bytes() +} +func (p *PacketUDP) Bytes() []byte { + packBytes := new(bytes.Buffer) + packBytes.Write(p.Header()) + packBytes.Write(p.data) + return packBytes.Bytes() +} +func (p *PacketUDP) Host() string { + return p.dstHost +} + +func (p *PacketUDP) Port() string { + return p.dstPort +} +func (p *PacketUDP) Data() []byte { + return p.data +} diff --git a/core/lib/transport/compress.go b/core/lib/transport/compress.go new file mode 100644 index 0000000..c26767e --- /dev/null +++ b/core/lib/transport/compress.go @@ -0,0 +1,59 @@ +package transport + +import ( + "net" + "time" + + "github.com/golang/snappy" +) + +func NewCompStream(conn net.Conn) *CompStream { + c := new(CompStream) + c.conn = conn + c.w = snappy.NewBufferedWriter(conn) + c.r = snappy.NewReader(conn) + return c +} +func NewCompConn(conn net.Conn) net.Conn { + c := CompStream{} + c.conn = conn + c.w = snappy.NewBufferedWriter(conn) + c.r = snappy.NewReader(conn) + return &c +} + +type CompStream struct { + net.Conn + conn net.Conn + w *snappy.Writer + r *snappy.Reader +} + +func (c *CompStream) Read(p []byte) (n int, err error) { + return c.r.Read(p) +} + +func (c *CompStream) Write(p []byte) (n int, err error) { + n, err = c.w.Write(p) + err = c.w.Flush() + return n, err +} + +func (c *CompStream) Close() error { + return c.conn.Close() +} +func (c *CompStream) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} +func (c *CompStream) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} +func (c *CompStream) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} +func (c *CompStream) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} +func (c *CompStream) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/core/lib/transport/encrypt/conn.go b/core/lib/transport/encrypt/conn.go new file mode 100644 index 0000000..db35bd8 --- /dev/null +++ b/core/lib/transport/encrypt/conn.go @@ -0,0 +1,40 @@ +package encrypt + +import ( + "crypto/cipher" + "io" + "net" + + lbuf "github.com/snail007/goproxy/core/lib/buf" +) + +var ( + lBuf = lbuf.NewLeakyBuf(2048, 2048) +) + +type Conn struct { + net.Conn + *Cipher + w io.Writer + r io.Reader +} + +func NewConn(c net.Conn, method, password string) (conn net.Conn, err error) { + cipher0, err := NewCipher(method, password) + if err != nil { + return + } + conn = &Conn{ + Conn: c, + Cipher: cipher0, + r: &cipher.StreamReader{S: cipher0.ReadStream, R: c}, + w: &cipher.StreamWriter{S: cipher0.WriteStream, W: c}, + } + return +} +func (s *Conn) Read(b []byte) (n int, err error) { + return s.r.Read(b) +} +func (s *Conn) Write(b []byte) (n int, err error) { + return s.w.Write(b) +} diff --git a/core/lib/transport/encrypt/encrypt.go b/core/lib/transport/encrypt/encrypt.go new file mode 100644 index 0000000..23b3a3f --- /dev/null +++ b/core/lib/transport/encrypt/encrypt.go @@ -0,0 +1,185 @@ +package encrypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/md5" + "crypto/rc4" + "crypto/sha256" + "errors" + + lbuf "github.com/snail007/goproxy/core/lib/buf" + "github.com/Yawning/chacha20" + "golang.org/x/crypto/blowfish" + "golang.org/x/crypto/cast5" +) + +const leakyBufSize = 2048 +const maxNBuf = 2048 + +var leakyBuf = lbuf.NewLeakyBuf(maxNBuf, leakyBufSize) +var errEmptyPassword = errors.New("proxy key") + +func md5sum(d []byte) []byte { + h := md5.New() + h.Write(d) + return h.Sum(nil) +} + +func evpBytesToKey(password string, keyLen int) (key []byte) { + const md5Len = 16 + cnt := (keyLen-1)/md5Len + 1 + m := make([]byte, cnt*md5Len) + copy(m, md5sum([]byte(password))) + + // Repeatedly call md5 until bytes generated is enough. + // Each call to md5 uses data: prev md5 sum + password. + d := make([]byte, md5Len+len(password)) + start := 0 + for i := 1; i < cnt; i++ { + start += md5Len + copy(d, m[start-md5Len:start]) + copy(d[md5Len:], password) + copy(m[start:], md5sum(d)) + } + return m[:keyLen] +} + +type DecOrEnc int + +const ( + Decrypt DecOrEnc = iota + Encrypt +) + +func newStream(block cipher.Block, err error, key, iv []byte, + doe DecOrEnc) (cipher.Stream, error) { + if err != nil { + return nil, err + } + if doe == Encrypt { + return cipher.NewCFBEncrypter(block, iv), nil + } else { + return cipher.NewCFBDecrypter(block, iv), nil + } +} + +func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := aes.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(block, iv), nil +} + +func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := des.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newBlowFishStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := blowfish.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newCast5Stream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := cast5.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + h := md5.New() + h.Write(key) + h.Write(iv) + rc4key := h.Sum(nil) + + return rc4.NewCipher(rc4key) +} + +func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + return chacha20.NewCipher(key, iv) +} + +func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + return chacha20.NewCipher(key, iv) +} + +type cipherInfo struct { + keyLen int + ivLen int + newStream func(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) +} + +var cipherMethod = map[string]*cipherInfo{ + "aes-128-cfb": {16, 16, newAESCFBStream}, + "aes-192-cfb": {24, 16, newAESCFBStream}, + "aes-256-cfb": {32, 16, newAESCFBStream}, + "aes-128-ctr": {16, 16, newAESCTRStream}, + "aes-192-ctr": {24, 16, newAESCTRStream}, + "aes-256-ctr": {32, 16, newAESCTRStream}, + "des-cfb": {8, 8, newDESStream}, + "bf-cfb": {16, 8, newBlowFishStream}, + "cast5-cfb": {16, 8, newCast5Stream}, + "rc4-md5": {16, 16, newRC4MD5Stream}, + "rc4-md5-6": {16, 6, newRC4MD5Stream}, + "chacha20": {32, 8, newChaCha20Stream}, + "chacha20-ietf": {32, 12, newChaCha20IETFStream}, +} + +func GetCipherMethods() (keys []string) { + keys = []string{} + for k := range cipherMethod { + keys = append(keys, k) + } + return +} +func CheckCipherMethod(method string) error { + if method == "" { + method = "aes-256-cfb" + } + _, ok := cipherMethod[method] + if !ok { + return errors.New("Unsupported encryption method: " + method) + } + return nil +} + +type Cipher struct { + WriteStream cipher.Stream + ReadStream cipher.Stream + key []byte + info *cipherInfo +} + +func NewCipher(method, password string) (c *Cipher, err error) { + if password == "" { + return nil, errEmptyPassword + } + mi, ok := cipherMethod[method] + if !ok { + return nil, errors.New("Unsupported encryption method: " + method) + } + key := evpBytesToKey(password, mi.keyLen) + c = &Cipher{key: key, info: mi} + if err != nil { + return nil, err + } + //hash(key) -> read IV + riv := sha256.New().Sum(c.key)[:c.info.ivLen] + c.ReadStream, err = c.info.newStream(c.key, riv, Decrypt) + if err != nil { + return nil, err + } //hash(read IV) -> write IV + wiv := sha256.New().Sum(riv)[:c.info.ivLen] + c.WriteStream, err = c.info.newStream(c.key, wiv, Encrypt) + if err != nil { + return nil, err + } + return c, nil +} diff --git a/core/lib/udp/udp.go b/core/lib/udp/udp.go new file mode 100644 index 0000000..284f195 --- /dev/null +++ b/core/lib/udp/udp.go @@ -0,0 +1,212 @@ +package udputils + +import ( + logger "log" + "net" + "strings" + "time" + + bufx "github.com/snail007/goproxy/core/lib/buf" + mapx "github.com/snail007/goproxy/core/lib/mapx" +) + +type CreateOutUDPConnFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, packet []byte) (outconn *net.UDPConn, err error) +type CleanFn func(srcAddr string) +type BeforeSendFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, b []byte) (sendB []byte, err error) +type BeforeReplyFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, outconn *net.UDPConn, b []byte) (replyB []byte, err error) + +type IOBinder struct { + outConns mapx.ConcurrentMap + listener *net.UDPConn + createOutUDPConnFn CreateOutUDPConnFn + log *logger.Logger + timeout time.Duration + cleanFn CleanFn + inTCPConn *net.Conn + outTCPConn *net.Conn + beforeSendFn BeforeSendFn + beforeReplyFn BeforeReplyFn +} + +func NewIOBinder(listener *net.UDPConn, log *logger.Logger) *IOBinder { + return &IOBinder{ + listener: listener, + outConns: mapx.NewConcurrentMap(), + log: log, + } +} +func (s *IOBinder) Factory(fn CreateOutUDPConnFn) *IOBinder { + s.createOutUDPConnFn = fn + return s +} +func (s *IOBinder) AfterReadFromClient(fn BeforeSendFn) *IOBinder { + s.beforeSendFn = fn + return s +} +func (s *IOBinder) AfterReadFromServer(fn BeforeReplyFn) *IOBinder { + s.beforeReplyFn = fn + return s +} +func (s *IOBinder) Timeout(timeout time.Duration) *IOBinder { + s.timeout = timeout + return s +} +func (s *IOBinder) Clean(fn CleanFn) *IOBinder { + s.cleanFn = fn + return s +} +func (s *IOBinder) AliveWithServeConn(srcAddr string, inTCPConn *net.Conn) *IOBinder { + s.inTCPConn = inTCPConn + go func() { + buf := make([]byte, 1) + (*inTCPConn).SetReadDeadline(time.Time{}) + if _, err := (*inTCPConn).Read(buf); err != nil { + s.log.Printf("udp related tcp conn of client disconnected with read , %s", err.Error()) + s.clean(srcAddr) + } + }() + go func() { + for { + (*inTCPConn).SetWriteDeadline(time.Now().Add(time.Second * 5)) + if _, err := (*inTCPConn).Write([]byte{0x00}); err != nil { + s.log.Printf("udp related tcp conn of client disconnected with write , %s", err.Error()) + s.clean(srcAddr) + return + } + (*inTCPConn).SetWriteDeadline(time.Time{}) + time.Sleep(time.Second * 5) + } + }() + return s +} +func (s *IOBinder) AliveWithClientConn(srcAddr string, outTCPConn *net.Conn) *IOBinder { + s.outTCPConn = outTCPConn + go func() { + buf := make([]byte, 1) + (*outTCPConn).SetReadDeadline(time.Time{}) + if _, err := (*outTCPConn).Read(buf); err != nil { + s.log.Printf("udp related tcp conn to parent disconnected with read , %s", err.Error()) + s.clean(srcAddr) + } + }() + return s +} +func (s *IOBinder) Run() (err error) { + var ( + isClosedErr = func(err error) bool { + return err != nil && strings.Contains(err.Error(), "use of closed network connection") + } + isTimeoutErr = func(err error) bool { + if err == nil { + return false + } + e, ok := err.(net.Error) + return ok && e.Timeout() + } + isRefusedErr = func(err error) bool { + return err != nil && strings.Contains(err.Error(), "connection refused") + } + ) + for { + buf := bufx.Get() + defer bufx.Put(buf) + n, srcAddr, err := s.listener.ReadFromUDP(buf) + if err != nil { + s.log.Printf("read from client error %s", err) + if isClosedErr(err) { + return err + } + continue + } + var data []byte + if s.beforeSendFn != nil { + data, err = s.beforeSendFn(s.listener, srcAddr, buf[:n]) + if err != nil { + s.log.Printf("beforeSend retured an error , %s", err) + continue + } + } else { + data = buf[:n] + } + inconnRemoteAddr := srcAddr.String() + var outconn *net.UDPConn + if v, ok := s.outConns.Get(inconnRemoteAddr); !ok { + outconn, err = s.createOutUDPConnFn(s.listener, srcAddr, data) + if err != nil { + s.log.Printf("connnect fail %s", err) + return err + } + go func() { + defer func() { + s.clean(srcAddr.String()) + }() + buf := bufx.Get() + defer bufx.Put(buf) + for { + if s.timeout > 0 { + outconn.SetReadDeadline(time.Now().Add(s.timeout)) + } + n, srcAddr, err := outconn.ReadFromUDP(buf) + if err != nil { + s.log.Printf("read from remote error %s", err) + if isClosedErr(err) || isTimeoutErr(err) || isRefusedErr(err) { + return + } + continue + } + data := buf[:n] + if s.beforeReplyFn != nil { + data, err = s.beforeReplyFn(s.listener, srcAddr, outconn, buf[:n]) + if err != nil { + s.log.Printf("beforeReply retured an error , %s", err) + continue + } + } + _, err = s.listener.WriteTo(data, srcAddr) + if err != nil { + s.log.Printf("write to remote error %s", err) + if isClosedErr(err) { + return + } + continue + } + } + }() + } else { + outconn = v.(*net.UDPConn) + } + + s.log.Printf("use decrpyted data , %v", data) + + _, err = outconn.Write(data) + + if err != nil { + s.log.Printf("write to remote error %s", err) + if isClosedErr(err) { + return err + } + } + } +} +func (s *IOBinder) clean(srcAddr string) *IOBinder { + if v, ok := s.outConns.Get(srcAddr); ok { + (*v.(*net.UDPConn)).Close() + s.outConns.Remove(srcAddr) + } + if s.inTCPConn != nil { + (*s.inTCPConn).Close() + } + if s.outTCPConn != nil { + (*s.outTCPConn).Close() + } + if s.cleanFn != nil { + s.cleanFn(srcAddr) + } + return s +} + +func (s *IOBinder) Close() { + for _, c := range s.outConns.Items() { + (*c.(*net.UDPConn)).Close() + } +} diff --git a/core/proxy/client/proxy.go b/core/proxy/client/proxy.go new file mode 100644 index 0000000..ffc3209 --- /dev/null +++ b/core/proxy/client/proxy.go @@ -0,0 +1,31 @@ +// Package proxy provides support for a variety of protocols to proxy network +// data. +package client + +import ( + "net" + "time" + + socks5c "github.com/snail007/goproxy/core/lib/socks5" + socks5 "github.com/snail007/goproxy/core/proxy/client/socks5" +) + +// A Dialer is a means to establish a connection. +type Dialer interface { + // Dial connects to the given address via the proxy. + DialConn(conn *net.Conn, network, addr string) (err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type Auth struct { + User, Password string +} + +func SOCKS5(timeout time.Duration, auth *Auth) (Dialer, error) { + var a *socks5c.UsernamePassword + if auth != nil { + a = &socks5c.UsernamePassword{auth.User, auth.Password} + } + d := socks5.NewDialer(a, timeout) + return d, nil +} diff --git a/core/proxy/client/socks5/socks5.go b/core/proxy/client/socks5/socks5.go new file mode 100644 index 0000000..61f82a3 --- /dev/null +++ b/core/proxy/client/socks5/socks5.go @@ -0,0 +1,263 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strconv" + "time" + + socks5c "github.com/snail007/goproxy/core/lib/socks5" +) + +type Dialer struct { + timeout time.Duration + usernamePassword *socks5c.UsernamePassword +} + +// NewDialer returns a new Dialer that dials through the provided +// proxy server's network and address. +func NewDialer(auth *socks5c.UsernamePassword, timeout time.Duration) *Dialer { + if auth != nil && auth.Password == "" && auth.Username == "" { + auth = nil + } + return &Dialer{ + usernamePassword: auth, + timeout: timeout, + } +} + +func (d *Dialer) DialConn(conn *net.Conn, network, addr string) (err error) { + client := NewClientConn(conn, network, addr, d.timeout, d.usernamePassword, nil) + err = client._Handshake() + return +} + +type ClientConn struct { + user string + password string + conn *net.Conn + header []byte + timeout time.Duration + addr string + network string + udpAddr string +} + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address +// with an optional username and password. See RFC 1928 and RFC 1929. +// target must be a canonical address with a host and port. +// network : tcp udp +func NewClientConn(conn *net.Conn, network, target string, timeout time.Duration, auth *socks5c.UsernamePassword, header []byte) *ClientConn { + s := &ClientConn{ + conn: conn, + network: network, + timeout: timeout, + } + if auth != nil { + s.user = auth.Username + s.password = auth.Password + } + if header != nil && len(header) > 0 { + s.header = header + } + if network == "udp" && target == "" { + target = "0.0.0.0:1" + } + s.addr = target + return s +} + +// connect takes an existing connection to a socks5 proxy server, +// and commands the server to extend that connection to target, +// which must be a canonical address with a host and port. +func (s *ClientConn) _Handshake() error { + host, portStr, err := net.SplitHostPort(s.addr) + if err != nil { + return err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return errors.New("proxy: failed to parse port number: " + portStr) + } + if port < 1 || port > 0xffff { + return errors.New("proxy: port number out of range: " + portStr) + } + + if err := s.auth(host); err != nil { + return err + } + buf := []byte{} + if s.network == "tcp" { + buf = append(buf, socks5c.VERSION_V5, socks5c.CMD_CONNECT, 0 /* reserved */) + + } else { + buf = append(buf, socks5c.VERSION_V5, socks5c.CMD_ASSOCIATE, 0 /* reserved */) + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + buf = append(buf, socks5c.ATYP_IPV4) + ip = ip4 + } else { + buf = append(buf, socks5c.ATYP_IPV6) + } + buf = append(buf, ip...) + } else { + if len(host) > 255 { + return errors.New("proxy: destination host name too long: " + host) + } + buf = append(buf, socks5c.ATYP_DOMAIN) + buf = append(buf, byte(len(host))) + buf = append(buf, host...) + } + buf = append(buf, byte(port>>8), byte(port)) + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := (*s.conn).Write(buf); err != nil { + return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + (*s.conn).SetDeadline(time.Time{}) + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := io.ReadFull((*s.conn), buf[:4]); err != nil { + return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + (*s.conn).SetDeadline(time.Time{}) + failure := "unknown error" + if int(buf[1]) < len(socks5c.Socks5Errors) { + failure = socks5c.Socks5Errors[buf[1]] + } + + if len(failure) > 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) + } + + bytesToDiscard := 0 + switch buf[3] { + case socks5c.ATYP_IPV4: + bytesToDiscard = net.IPv4len + case socks5c.ATYP_IPV6: + bytesToDiscard = net.IPv6len + case socks5c.ATYP_DOMAIN: + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + _, err := io.ReadFull((*s.conn), buf[:1]) + (*s.conn).SetDeadline(time.Time{}) + if err != nil { + return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + bytesToDiscard = int(buf[0]) + default: + return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) + } + + if cap(buf) < bytesToDiscard { + buf = make([]byte, bytesToDiscard) + } else { + buf = buf[:bytesToDiscard] + } + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := io.ReadFull((*s.conn), buf); err != nil { + return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + (*s.conn).SetDeadline(time.Time{}) + var ip net.IP + ip = buf + ipStr := "" + if bytesToDiscard == net.IPv4len || bytesToDiscard == net.IPv6len { + if ipv4 := ip.To4(); ipv4 != nil { + ipStr = ipv4.String() + } else { + ipStr = ip.To16().String() + } + } + //log.Printf("%v", ipStr) + // Also need to discard the port number + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil { + return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + p := binary.BigEndian.Uint16([]byte{buf[0], buf[1]}) + //log.Printf("%v", p) + s.udpAddr = net.JoinHostPort(ipStr, fmt.Sprintf("%d", p)) + //log.Printf("%v", s.udpAddr) + (*s.conn).SetDeadline(time.Time{}) + return nil +} +func (s *ClientConn) SendUDP(data []byte, addr string) (respData []byte, err error) { + + c, err := net.DialTimeout("udp", s.udpAddr, s.timeout) + if err != nil { + return + } + conn := c.(*net.UDPConn) + + p := socks5c.NewPacketUDP() + p.Build(addr, data) + conn.SetDeadline(time.Now().Add(s.timeout)) + conn.Write(p.Bytes()) + conn.SetDeadline(time.Time{}) + + buf := make([]byte, 1024) + conn.SetDeadline(time.Now().Add(s.timeout)) + n, _, err := conn.ReadFrom(buf) + conn.SetDeadline(time.Time{}) + if err != nil { + return + } + respData = buf[:n] + return +} +func (s *ClientConn) auth(host string) error { + + // the size here is just an estimate + buf := make([]byte, 0, 6+len(host)) + + buf = append(buf, socks5c.VERSION_V5) + if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { + buf = append(buf, 2 /* num auth methods */, socks5c.Method_NO_AUTH, socks5c.Method_USER_PASS) + } else { + buf = append(buf, 1 /* num auth methods */, socks5c.Method_NO_AUTH) + } + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := (*s.conn).Write(buf); err != nil { + return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + (*s.conn).SetDeadline(time.Time{}) + + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil { + return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + (*s.conn).SetDeadline(time.Time{}) + + if buf[0] != 5 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) + } + if buf[1] == 0xff { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") + } + + // See RFC 1929 + if buf[1] == socks5c.Method_USER_PASS { + buf = buf[:0] + buf = append(buf, 1 /* password protocol version */) + buf = append(buf, uint8(len(s.user))) + buf = append(buf, s.user...) + buf = append(buf, uint8(len(s.password))) + buf = append(buf, s.password...) + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := (*s.conn).Write(buf); err != nil { + return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + (*s.conn).SetDeadline(time.Time{}) + (*s.conn).SetDeadline(time.Now().Add(s.timeout)) + if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil { + return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + (*s.conn).SetDeadline(time.Time{}) + if buf[1] != 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") + } + } + return nil +} diff --git a/core/proxy/client/tests/proxy_test.go b/core/proxy/client/tests/proxy_test.go new file mode 100644 index 0000000..28f8645 --- /dev/null +++ b/core/proxy/client/tests/proxy_test.go @@ -0,0 +1,79 @@ +package tests + +import ( + "io/ioutil" + "net" + "os" + "strings" + "testing" + "time" + + proxyclient "github.com/snail007/goproxy/core/proxy/client" + sdk "github.com/snail007/goproxy/sdk/android-ios" +) + +func TestSocks5(t *testing.T) { + estr := sdk.Start("s1", "socks -p :8185 --log test.log") + if estr != "" { + t.Fatal(estr) + } + p, e := proxyclient.SOCKS5(time.Second, nil) + if e != nil { + t.Error(e) + } else { + c, e := net.Dial("tcp", "127.0.0.1:8185") + if e != nil { + t.Fatal(e) + } + e = p.DialConn(&c, "tcp", "www.baidu.com:80") + if e != nil { + t.Fatal(e) + } + _, e = c.Write([]byte("Get / http/1.1\r\nHost: www.baidu.com\r\n")) + if e != nil { + t.Fatal(e) + } + b, e := ioutil.ReadAll(c) + if e != nil { + t.Fatal(e) + } + if !strings.HasPrefix(string(b), "HTTP") { + t.Fatalf("request baidu fail:%s", string(b)) + } + } + sdk.Stop("s1") + os.Remove("test.log") +} + +func TestSocks5Auth(t *testing.T) { + estr := sdk.Start("s1", "socks -p :8185 -a u:p --log test.log") + if estr != "" { + t.Fatal(estr) + } + p, e := proxyclient.SOCKS5(time.Second, &proxyclient.Auth{User: "u", Password: "p"}) + if e != nil { + t.Error(e) + } else { + c, e := net.Dial("tcp", "127.0.0.1:8185") + if e != nil { + t.Fatal(e) + } + e = p.DialConn(&c, "tcp", "www.baidu.com:80") + if e != nil { + t.Fatal(e) + } + _, e = c.Write([]byte("Get / http/1.1\r\nHost: www.baidu.com\r\n")) + if e != nil { + t.Fatal(e) + } + b, e := ioutil.ReadAll(c) + if e != nil { + t.Fatal(e) + } + if !strings.HasPrefix(string(b), "HTTP") { + t.Fatalf("request baidu fail:%s", string(b)) + } + } + sdk.Stop("s1") + os.Remove("test.log") +} diff --git a/core/proxy/server/socks5/server.go b/core/proxy/server/socks5/server.go new file mode 100644 index 0000000..fff8d03 --- /dev/null +++ b/core/proxy/server/socks5/server.go @@ -0,0 +1,373 @@ +package socks5 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + socks5c "github.com/snail007/goproxy/core/proxy/common/socks5" +) + +type BasicAuther interface { + CheckUserPass(username, password, fromIP, ToTarget string) bool +} +type Request struct { + ver uint8 + cmd uint8 + reserve uint8 + addressType uint8 + dstAddr string + dstPort string + dstHost string + bytes []byte + rw io.ReadWriter +} + +func NewRequest(rw io.ReadWriter, header ...[]byte) (req Request, err interface{}) { + var b = make([]byte, 1024) + var n int + req = Request{rw: rw} + if header != nil && len(header) == 1 && len(header[0]) > 1 { + b = header[0] + n = len(header[0]) + } else { + n, err = rw.Read(b[:]) + if err != nil { + err = fmt.Errorf("read req data fail,ERR: %s", err) + return + } + } + req.ver = uint8(b[0]) + req.cmd = uint8(b[1]) + req.reserve = uint8(b[2]) + req.addressType = uint8(b[3]) + if b[0] != 0x5 { + err = fmt.Errorf("sosck version supported") + req.TCPReply(socks5c.REP_REQ_FAIL) + return + } + switch b[3] { + case 0x01: //IP V4 + req.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String() + case 0x03: //域名 + req.dstHost = string(b[5 : n-2]) //b[4]表示域名的长度 + case 0x04: //IP V6 + req.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String() + } + req.dstPort = strconv.Itoa(int(b[n-2])<<8 | int(b[n-1])) + req.dstAddr = net.JoinHostPort(req.dstHost, req.dstPort) + req.bytes = b[:n] + return +} +func (s *Request) Bytes() []byte { + return s.bytes +} +func (s *Request) Addr() string { + return s.dstAddr +} +func (s *Request) Host() string { + return s.dstHost +} +func (s *Request) Port() string { + return s.dstPort +} +func (s *Request) AType() uint8 { + return s.addressType +} +func (s *Request) CMD() uint8 { + return s.cmd +} + +func (s *Request) TCPReply(rep uint8) (err error) { + _, err = s.rw.Write(s.NewReply(rep, "0.0.0.0:0")) + return +} +func (s *Request) UDPReply(rep uint8, addr string) (err error) { + _, err = s.rw.Write(s.NewReply(rep, addr)) + return +} +func (s *Request) NewReply(rep uint8, addr string) []byte { + var response bytes.Buffer + host, port, _ := net.SplitHostPort(addr) + ip := net.ParseIP(host) + ipb := ip.To4() + atyp := socks5c.ATYP_IPV4 + ipv6 := ip.To16() + zeroiIPv6 := fmt.Sprintf("%d%d%d%d%d%d%d%d%d%d%d%d", + ipv6[0], ipv6[1], ipv6[2], ipv6[3], + ipv6[4], ipv6[5], ipv6[6], ipv6[7], + ipv6[8], ipv6[9], ipv6[10], ipv6[11], + ) + if ipb == nil && ipv6 != nil && "0000000000255255" != zeroiIPv6 { + atyp = socks5c.ATYP_IPV6 + ipb = ip.To16() + } + porti, _ := strconv.Atoi(port) + portb := make([]byte, 2) + binary.BigEndian.PutUint16(portb, uint16(porti)) + // log.Printf("atyp : %v", atyp) + // log.Printf("ip : %v", []byte(ip)) + response.WriteByte(socks5c.VERSION_V5) + response.WriteByte(rep) + response.WriteByte(socks5c.RSV) + response.WriteByte(atyp) + response.Write(ipb) + response.Write(portb) + return response.Bytes() +} + +type MethodsRequest struct { + ver uint8 + methodsCount uint8 + methods []uint8 + bytes []byte + rw *io.ReadWriter +} + +func NewMethodsRequest(r io.ReadWriter, header ...[]byte) (s MethodsRequest, err interface{}) { + defer func() { + if err == nil { + err = recover() + } + }() + s = MethodsRequest{} + s.rw = &r + var buf = make([]byte, 300) + var n int + if header != nil && len(header) == 1 && len(header[0]) > 1 { + buf = header[0] + n = len(header[0]) + } else { + n, err = r.Read(buf) + if err != nil { + return + } + } + if buf[0] != 0x05 { + err = fmt.Errorf("socks version not supported") + return + } + if n != int(buf[1])+int(2) { + err = fmt.Errorf("socks methods data length error") + return + } + s.ver = buf[0] + s.methodsCount = buf[1] + s.methods = buf[2:n] + s.bytes = buf[:n] + return +} +func (s *MethodsRequest) Version() uint8 { + return s.ver +} +func (s *MethodsRequest) MethodsCount() uint8 { + return s.methodsCount +} +func (s *MethodsRequest) Methods() []uint8 { + return s.methods +} +func (s *MethodsRequest) Select(method uint8) bool { + for _, m := range s.methods { + if m == method { + return true + } + } + return false +} +func (s *MethodsRequest) Reply(method uint8) (err error) { + _, err = (*s.rw).Write([]byte{byte(socks5c.VERSION_V5), byte(method)}) + return +} +func (s *MethodsRequest) Bytes() []byte { + return s.bytes +} + +type ServerConn struct { + target string + user string + password string + conn *net.Conn + timeout time.Duration + auth *BasicAuther + header []byte + ver uint8 + //method + methodsCount uint8 + methods []uint8 + method uint8 + //request + cmd uint8 + reserve uint8 + addressType uint8 + dstAddr string + dstPort string + dstHost string + udpAddress string +} + +func NewServerConn(conn *net.Conn, timeout time.Duration, auth *BasicAuther, udpAddress string, header []byte) *ServerConn { + if udpAddress == "" { + udpAddress = "0.0.0.0:16666" + } + s := &ServerConn{ + conn: conn, + timeout: timeout, + auth: auth, + header: header, + ver: socks5c.VERSION_V5, + udpAddress: udpAddress, + } + return s + +} +func (s *ServerConn) Close() { + (*s.conn).Close() +} +func (s *ServerConn) AuthData() socks5c.UsernamePassword { + return socks5c.UsernamePassword{s.user, s.password} +} +func (s *ServerConn) Method() uint8 { + return s.method +} +func (s *ServerConn) Target() string { + return s.target +} +func (s *ServerConn) Handshake() (err error) { + remoteAddr := (*s.conn).RemoteAddr() + //协商开始 + //method select request + var methodReq MethodsRequest + (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + + methodReq, e := NewMethodsRequest((*s.conn), s.header) + (*s.conn).SetReadDeadline(time.Time{}) + if e != nil { + (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + methodReq.Reply(socks5c.Method_NONE_ACCEPTABLE) + (*s.conn).SetReadDeadline(time.Time{}) + err = fmt.Errorf("new methods request fail,ERR: %s", e) + return + } + //log.Printf("%v,s.auth == %v && methodReq.Select(Method_NO_AUTH) %v", methodReq.methods, s.auth, methodReq.Select(Method_NO_AUTH)) + if s.auth == nil && methodReq.Select(socks5c.Method_NO_AUTH) && !methodReq.Select(socks5c.Method_USER_PASS) { + // if !methodReq.Select(Method_NO_AUTH) { + // (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + // methodReq.Reply(Method_NONE_ACCEPTABLE) + // (*s.conn).SetReadDeadline(time.Time{}) + // err = fmt.Errorf("none method found : Method_NO_AUTH") + // return + // } + s.method = socks5c.Method_NO_AUTH + //method select reply + (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + err = methodReq.Reply(socks5c.Method_NO_AUTH) + (*s.conn).SetReadDeadline(time.Time{}) + if err != nil { + err = fmt.Errorf("reply answer data fail,ERR: %s", err) + return + } + // err = fmt.Errorf("% x", methodReq.Bytes()) + } else { + //auth + if !methodReq.Select(socks5c.Method_USER_PASS) { + (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + methodReq.Reply(socks5c.Method_NONE_ACCEPTABLE) + (*s.conn).SetReadDeadline(time.Time{}) + err = fmt.Errorf("none method found : Method_USER_PASS") + return + } + s.method = socks5c.Method_USER_PASS + //method reply need auth + (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + err = methodReq.Reply(socks5c.Method_USER_PASS) + (*s.conn).SetReadDeadline(time.Time{}) + if err != nil { + err = fmt.Errorf("reply answer data fail,ERR: %s", err) + return + } + //read auth + buf := make([]byte, 500) + var n int + (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + n, err = (*s.conn).Read(buf) + (*s.conn).SetReadDeadline(time.Time{}) + if err != nil { + err = fmt.Errorf("read auth info fail,ERR: %s", err) + return + } + r := buf[:n] + s.user = string(r[2 : r[1]+2]) + s.password = string(r[2+r[1]+1:]) + //err = fmt.Errorf("user:%s,pass:%s", user, pass) + //auth + _addr := strings.Split(remoteAddr.String(), ":") + if s.auth == nil || (*s.auth).CheckUserPass(s.user, s.password, _addr[0], "") { + (*s.conn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(s.timeout))) + _, err = (*s.conn).Write([]byte{0x01, 0x00}) + (*s.conn).SetDeadline(time.Time{}) + if err != nil { + err = fmt.Errorf("answer auth success to %s fail,ERR: %s", remoteAddr, err) + return + } + } else { + (*s.conn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(s.timeout))) + _, err = (*s.conn).Write([]byte{0x01, 0x01}) + (*s.conn).SetDeadline(time.Time{}) + if err != nil { + err = fmt.Errorf("answer auth fail to %s fail,ERR: %s", remoteAddr, err) + return + } + err = fmt.Errorf("auth fail from %s", remoteAddr) + return + } + } + //request detail + (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout)) + request, e := NewRequest(*s.conn) + (*s.conn).SetReadDeadline(time.Time{}) + if e != nil { + err = fmt.Errorf("read request data fail,ERR: %s", e) + return + } + //协商结束 + + switch request.CMD() { + case socks5c.CMD_BIND: + err = request.TCPReply(socks5c.REP_UNKNOWN) + if err != nil { + err = fmt.Errorf("TCPReply REP_UNKNOWN to %s fail,ERR: %s", remoteAddr, err) + return + } + err = fmt.Errorf("cmd bind not supported, form: %s", remoteAddr) + return + case socks5c.CMD_CONNECT: + err = request.TCPReply(socks5c.REP_SUCCESS) + if err != nil { + err = fmt.Errorf("TCPReply REP_SUCCESS to %s fail,ERR: %s", remoteAddr, err) + return + } + case socks5c.CMD_ASSOCIATE: + err = request.UDPReply(socks5c.REP_SUCCESS, s.udpAddress) + if err != nil { + err = fmt.Errorf("UDPReply REP_SUCCESS to %s fail,ERR: %s", remoteAddr, err) + return + } + } + + //fill socks info + s.target = request.Addr() + s.methodsCount = methodReq.MethodsCount() + s.methods = methodReq.Methods() + s.cmd = request.CMD() + s.reserve = request.reserve + s.addressType = request.addressType + s.dstAddr = request.dstAddr + s.dstHost = request.dstHost + s.dstPort = request.dstPort + return +} diff --git a/core/tproxy/README.md b/core/tproxy/README.md new file mode 100644 index 0000000..13de09b --- /dev/null +++ b/core/tproxy/README.md @@ -0,0 +1,35 @@ +# 透传用户IP手册 + +说明: + +通过Linux的TPROXY功能,可以实现源站服务程序可以看见客户端真实IP,实现该功能需要linux操作系统和程序都要满足一定的条件. + +环境要求: + +源站必须是运行在Linux上面的服务程序,同时Linux需要满足下面条件: + +1.Linux内核版本 >= 2.6.28 + +2.判断系统是否支持TPROXY,执行: + + grep TPROXY /boot/config-`uname -r` + + 如果输出有下面的结果说明支持. + + CONFIG_NETFILTER_XT_TARGET_TPROXY=m + +部署步骤: + +1.在源站的linux系统里面每次开机启动都要用root权限执行tproxy环境设置脚本:tproxy_setup.sh + +2.在源站的linux系统里面使用root权限执行代理proxy + +参数 -tproxy 是开启代理的tproxy功能. + +./proxy -tproxy + +2.源站的程序监听的地址IP需要使用:127.0.1.1 + +比如源站以前监听的地址是: 0.0.0.0:8800 , 现在需要修改为:127.0.1.1:8800 + +3.转发规则里面源站地址必须是对应的,比如上面的:127.0.1.1:8800 diff --git a/core/tproxy/tproxy.go b/core/tproxy/tproxy.go new file mode 100644 index 0000000..ca754c2 --- /dev/null +++ b/core/tproxy/tproxy.go @@ -0,0 +1,249 @@ +// Package tproxy provides the TCPDial and TCPListen tproxy equivalent of the +// net package Dial and Listen with tproxy support for linux ONLY. +package tproxy + +import ( + "fmt" + "net" + "os" + "time" + + "golang.org/x/sys/unix" +) + +const big = 0xFFFFFF +const IP_ORIGADDRS = 20 + +// Debug outs the library in Debug mode +var Debug = false + +func ipToSocksAddr(family int, ip net.IP, port int, zone string) (unix.Sockaddr, error) { + switch family { + case unix.AF_INET: + if len(ip) == 0 { + ip = net.IPv4zero + } + if ip = ip.To4(); ip == nil { + return nil, net.InvalidAddrError("non-IPv4 address") + } + sa := new(unix.SockaddrInet4) + for i := 0; i < net.IPv4len; i++ { + sa.Addr[i] = ip[i] + } + sa.Port = port + return sa, nil + case unix.AF_INET6: + if len(ip) == 0 { + ip = net.IPv6zero + } + // IPv4 callers use 0.0.0.0 to mean "announce on any available address". + // In IPv6 mode, Linux treats that as meaning "announce on 0.0.0.0", + // which it refuses to do. Rewrite to the IPv6 unspecified address. + if ip.Equal(net.IPv4zero) { + ip = net.IPv6zero + } + if ip = ip.To16(); ip == nil { + return nil, net.InvalidAddrError("non-IPv6 address") + } + sa := new(unix.SockaddrInet6) + for i := 0; i < net.IPv6len; i++ { + sa.Addr[i] = ip[i] + } + sa.Port = port + sa.ZoneId = uint32(zoneToInt(zone)) + return sa, nil + } + return nil, net.InvalidAddrError("unexpected socket family") +} + +func zoneToInt(zone string) int { + if zone == "" { + return 0 + } + if ifi, err := net.InterfaceByName(zone); err == nil { + return ifi.Index + } + n, _, _ := dtoi(zone, 0) + return n +} + +func dtoi(s string, i0 int) (n int, i int, ok bool) { + n = 0 + for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { + n = n*10 + int(s[i]-'0') + if n >= big { + return 0, i, false + } + } + if i == i0 { + return 0, i, false + } + return n, i, true +} + +// IPTcpAddrToUnixSocksAddr --- +func IPTcpAddrToUnixSocksAddr(addr string) (sa unix.Sockaddr, err error) { + if Debug { + fmt.Println("DEBUG: IPTcpAddrToUnixSocksAddr recieved address:", addr) + } + addressNet := "tcp6" + if addr[0] != '[' { + addressNet = "tcp4" + } + tcpAddr, err := net.ResolveTCPAddr(addressNet, addr) + if err != nil { + return nil, err + } + return ipToSocksAddr(ipType(addr), tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone) +} + +// IPv6UdpAddrToUnixSocksAddr --- +func IPv6UdpAddrToUnixSocksAddr(addr string) (sa unix.Sockaddr, err error) { + tcpAddr, err := net.ResolveTCPAddr("udp6", addr) + if err != nil { + return nil, err + } + return ipToSocksAddr(unix.AF_INET6, tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone) +} + +// TCPListen is listening for incoming IP packets which are being intercepted. +// In conflict to regular Listen mehtod the socket destination and source addresses +// are of the intercepted connection. +// Else then that it works exactly like net package net.Listen. +func TCPListen(listenAddr string) (listener net.Listener, err error) { + s, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0) + if err != nil { + return nil, err + } + defer unix.Close(s) + err = unix.SetsockoptInt(s, unix.SOL_IP, unix.IP_TRANSPARENT, 1) + if err != nil { + return nil, err + } + + sa, err := IPTcpAddrToUnixSocksAddr(listenAddr) + if err != nil { + return nil, err + } + err = unix.Bind(s, sa) + if err != nil { + return nil, err + } + err = unix.Listen(s, unix.SOMAXCONN) + if err != nil { + return nil, err + } + f := os.NewFile(uintptr(s), "TProxy") + defer f.Close() + return net.FileListener(f) +} +func ipType(localAddr string) int { + host, _, _ := net.SplitHostPort(localAddr) + if host != "" { + ip := net.ParseIP(host) + if ip == nil || ip.To4() != nil { + return unix.AF_INET + } + return unix.AF_INET6 + } + return unix.AF_INET +} + +// TCPDial is a special tcp connection which binds a non local address as the source. +// Except then the option to bind to a specific local address which the machine doesn't posses +// it is exactly like any other net.Conn connection. +// It is advised to use port numbered 0 in the localAddr and leave the kernel to choose which +// Local port to use in order to avoid errors and binding conflicts. +func TCPDial(localAddr, remoteAddr string, timeout time.Duration) (conn net.Conn, err error) { + timer := time.NewTimer(timeout) + defer timer.Stop() + if Debug { + fmt.Println("TCPDial from:", localAddr, "to:", remoteAddr) + } + s, err := unix.Socket(ipType(localAddr), unix.SOCK_STREAM, 0) + + //In a case there was a need for a non-blocking socket an example + //s, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM |unix.SOCK_NONBLOCK, 0) + if err != nil { + fmt.Println(err) + return nil, err + } + defer unix.Close(s) + err = unix.SetsockoptInt(s, unix.SOL_IP, unix.IP_TRANSPARENT, 1) + if err != nil { + if Debug { + fmt.Println("ERROR setting the socket in IP_TRANSPARENT mode", err) + } + + return nil, err + } + err = unix.SetsockoptInt(s, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + if err != nil { + if Debug { + fmt.Println("ERROR setting the socket in unix.SO_REUSEADDR mode", err) + } + return nil, err + } + + rhost, _, err := net.SplitHostPort(localAddr) + if err != nil { + if Debug { + // fmt.Fprintln(os.Stderr, err) + fmt.Println("ERROR", err, "running net.SplitHostPort on address:", localAddr) + } + } + + sa, err := IPTcpAddrToUnixSocksAddr(rhost + ":0") + if err != nil { + if Debug { + fmt.Println("ERROR creating a hostaddres for the socker with IPTcpAddrToUnixSocksAddr", err) + } + return nil, err + } + + remoteSocket, err := IPTcpAddrToUnixSocksAddr(remoteAddr) + if err != nil { + if Debug { + fmt.Println("ERROR creating a remoteSocket for the socker with IPTcpAddrToUnixSocksAddr on the remote addres", err) + } + return nil, err + } + + err = unix.Bind(s, sa) + if err != nil { + fmt.Println(err) + return nil, err + } + + errChn := make(chan error, 1) + func() { + err = unix.Connect(s, remoteSocket) + if err != nil { + if Debug { + fmt.Println("ERROR Connecting from", s, "to:", remoteSocket, "ERROR:", err) + } + } + errChn <- err + }() + + select { + case err = <-errChn: + if err != nil { + return nil, err + } + case <-timer.C: + return nil, fmt.Errorf("ERROR connect to %s timeout", remoteAddr) + } + f := os.NewFile(uintptr(s), "TProxyTCPClient") + client, err := net.FileConn(f) + if err != nil { + if Debug { + fmt.Println("ERROR os.NewFile", err) + } + return nil, err + } + if Debug { + fmt.Println("FINISHED Creating net.coo from:", client.LocalAddr().String(), "to:", client.RemoteAddr().String()) + } + return client, err +} diff --git a/core/tproxy/tproxy_setup.sh b/core/tproxy/tproxy_setup.sh new file mode 100644 index 0000000..c5a6fd4 --- /dev/null +++ b/core/tproxy/tproxy_setup.sh @@ -0,0 +1,30 @@ +#!/bin/bash +SOURCE_BIND_IP="127.0.1.1" + +echo 0 > /proc/sys/net/ipv4/conf/lo/rp_filter +echo 2 > /proc/sys/net/ipv4/conf/default/rp_filter +echo 2 > /proc/sys/net/ipv4/conf/all/rp_filter +echo 1 > /proc/sys/net/ipv4/conf/all/send_redirects +echo 1 > /proc/sys/net/ipv4/conf/all/forwarding +echo 1 > /proc/sys/net/ipv4/ip_forward + +# 本地的话,貌似这段不需要 +# iptables -t mangle -N DIVERT >/dev/null 2>&1 +# iptables -t mangle -F DIVERT +# iptables -t mangle -D PREROUTING -p tcp -m socket -j DIVERT >/dev/null 2>&1 +# iptables -t mangle -A PREROUTING -p tcp -m socket -j DIVERT +# iptables -t mangle -A DIVERT -j MARK --set-mark 1 +# iptables -t mangle -A DIVERT -j ACCEPT + +ip rule del fwmark 1 lookup 100 +ip rule add fwmark 1 lookup 100 +ip route del local 0.0.0.0/0 dev lo table 100 +ip route add local 0.0.0.0/0 dev lo table 100 + +ip rule del from ${SOURCE_BIND_IP} table 101 +ip rule add from ${SOURCE_BIND_IP} table 101 +ip route del default via 127.0.0.1 dev lo table 101 +ip route add default via 127.0.0.1 dev lo table 101 + +ip route flush cache +ip ro flush cache \ No newline at end of file