goproxy/core/dst/mux.go
arraykeys@gmail.com 7bb5bb86dc 优化crashed日志
2018-09-14 11:47:50 +08:00

430 lines
9.6 KiB
Go

// Copyright 2014 The DST Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package dst
import (
"fmt"
"runtime/debug"
"net"
"sync"
"time"
)
// Tunables for the multiplexer and its handshake.
const (
// maxIncomingRequests is the buffer size of the channel that holds
// connections which have completed a handshake but have not yet been
// returned by Accept.
maxIncomingRequests = 1024
// maxPacketSize is the packet size NewMux falls back to when the caller
// passes a non-positive value.
maxPacketSize = 500
// handshakeTimeout is how long a client handshake may run before it is
// abandoned with ErrHandshakeTimeout.
handshakeTimeout = 5 * time.Second
// handshakeInterval is the delay between retransmissions of an unanswered
// handshake request.
handshakeInterval = 1 * time.Second
)
// Mux is a UDP multiplexer of DST connections.
//
// A single Mux owns one packet connection and routes every datagram read
// from it either to an established Conn or to a handshake in progress.
type Mux struct {
// conn is the underlying packet connection all datagrams are read from
// and written to.
conn net.PacketConn
// packetSize is the size of the pooled buffers and the upper bound on
// the packet size negotiated for any connection.
packetSize int
// conns holds the established connections, keyed by local connection ID.
conns map[connectionID]*Conn
// handshakes holds, per outgoing handshake in progress, the channel on
// which handshake replies are delivered to the dialer.
handshakes map[connectionID]chan packet
// connsMut guards conns and handshakes.
connsMut sync.Mutex
// incoming queues connections created from remote handshake requests
// until Accept/AcceptDST picks them up. It is closed by Close.
incoming chan *Conn
// closed is closed by Close to signal shutdown to pending handshakes.
closed chan struct{}
// closeOnce makes sure the shutdown sequence in Close runs only once.
closeOnce sync.Once
// buffers is a pool of byte slices of length packetSize.
buffers *sync.Pool
}
// NewMux creates a new DST Mux on top of a packet connection.
//
// packetSize is the largest datagram the Mux will read or write; a value
// of zero or less selects maxPacketSize. When conn is a *net.UDPConn the
// socket buffers are enlarged as far as the operating system allows.
//
// NewMux starts a goroutine that runs the reader loop. A panic in that
// goroutine is recovered and reported, with its stack trace, on standard
// output.
func NewMux(conn net.PacketConn, packetSize int) *Mux {
	if packetSize <= 0 {
		packetSize = maxPacketSize
	}

	m := &Mux{
		conn:       conn,
		packetSize: packetSize,
		conns:      map[connectionID]*Conn{},
		handshakes: make(map[connectionID]chan packet),
		incoming:   make(chan *Conn, maxIncomingRequests),
		closed:     make(chan struct{}),
		buffers: &sync.Pool{
			New: func() interface{} {
				return make([]byte, packetSize)
			},
		},
	}

	// Attempt to maximize buffer space. Start at 16 MB and work downwards 0.5
	// MB at a time, keeping the first size the kernel accepts.
	if udpConn, ok := conn.(*net.UDPConn); ok {
		maximize := func(what string, set func(int) error) {
			for size := 16384 * 1024; size >= 512*1024; size -= 512 * 1024 {
				if set(size) != nil {
					continue
				}
				if debugMux {
					log.Println(m, what, size)
				}
				return
			}
		}
		maximize("read buffer is", udpConn.SetReadBuffer)
		maximize("write buffer is", udpConn.SetWriteBuffer)
	}

	go func() {
		defer func() {
			if e := recover(); e != nil {
				// %v because a panic value can be of any type; the stack
				// needs its own verb or fmt reports it as an EXTRA argument.
				fmt.Printf("crashed, err: %v\nstack: %s\n", e, debug.Stack())
			}
		}()
		m.readerLoop()
	}()

	return m
}
// Accept waits for and returns the next connection to the listener.
//
// It is the net.Listener-shaped wrapper around AcceptDST. On failure the
// returned net.Conn is a true nil interface value rather than an interface
// wrapping a nil *Conn, so callers may safely compare it against nil.
func (m *Mux) Accept() (net.Conn, error) {
	conn, err := m.AcceptDST()
	if err != nil {
		return nil, err
	}
	return conn, nil
}
// AcceptDST blocks until a remote peer has completed a handshake and
// returns the resulting connection. Once the Mux has been closed and the
// queue of pending connections is drained it returns ErrClosedMux.
func (m *Mux) AcceptDST() (*Conn, error) {
	if accepted, open := <-m.incoming; open {
		return accepted, nil
	}
	return nil, ErrClosedMux
}
// Close shuts the listener down.
//
// The first call closes the underlying packet connection, which also
// unblocks any pending Accept, and returns the error from closing that
// connection. Every later call is a no-op that returns ErrClosedMux.
func (m *Mux) Close() error {
	// Pessimistic default: stays in place only when the shutdown sequence
	// below has already run on an earlier call.
	result := error(ErrClosedMux)

	shutdown := func() {
		result = m.conn.Close()
		close(m.incoming)
		close(m.closed)
	}
	m.closeOnce.Do(shutdown)

	return result
}
// Addr reports the local network address of the packet connection the
// listener is bound to.
func (m *Mux) Addr() net.Addr {
	local := m.conn.LocalAddr()
	return local
}
// Dial connects to the address on the named network.
//
// Network must be "dst".
//
// Addresses have the form host:port. If host is a literal IPv6 address or
// host name, it must be enclosed in square brackets as in "[::1]:80",
// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
// SplitHostPort manipulate addresses in this form.
//
// Examples:
//
//	Dial("dst", "12.34.56.78:80")
//	Dial("dst", "google.com:http")
//	Dial("dst", "[2001:db8::1]:http")
//	Dial("dst", "[fe80::1%lo0]:80")
//
// Dial is the net.Conn-returning wrapper around DialDST. On failure the
// returned net.Conn is a true nil interface value rather than an interface
// wrapping a nil *Conn, so callers may safely compare it against nil.
func (m *Mux) Dial(network, addr string) (net.Conn, error) {
	conn, err := m.DialDST(network, addr)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
// DialDST connects to the address on the named network and returns the
// concrete *Conn.
//
// Network must be "dst"; anything else yields ErrNotDST.
//
// Addresses have the form host:port. If host is a literal IPv6 address or
// host name, it must be enclosed in square brackets as in "[::1]:80",
// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
// SplitHostPort manipulate addresses in this form.
//
// Examples:
//
//	DialDST("dst", "12.34.56.78:80")
//	DialDST("dst", "google.com:http")
//	DialDST("dst", "[2001:db8::1]:http")
//	DialDST("dst", "[fe80::1%lo0]:80")
//
// An address that cannot be resolved as a UDP address, or a handshake that
// fails, is reported through the returned error.
func (m *Mux) DialDST(network, addr string) (*Conn, error) {
	if network != "dst" {
		return nil, ErrNotDST
	}

	dst, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}

	// One slot of buffering: the reader goroutine delivers handshake
	// replies with a blocking send after looking this channel up. A reply
	// that lands between the handshake finishing and the entry being
	// deleted below would otherwise have no receiver and stall the reader
	// for good. This narrows that window; it cannot absorb more than one
	// such late reply.
	resp := make(chan packet, 1)

	m.connsMut.Lock()
	connID := m.newConnID()
	m.handshakes[connID] = resp
	m.connsMut.Unlock()

	conn, err := m.clientHandshake(dst, connID, resp)

	m.connsMut.Lock()
	defer m.connsMut.Unlock()

	// The handshake is over either way; stop routing replies to it.
	delete(m.handshakes, connID)
	if err != nil {
		return nil, err
	}

	m.conns[connID] = conn
	return conn, nil
}
// clientHandshake performs the client side handshake (i.e. Dial).
//
// It sends a handshake request to dst right away and again every
// handshakeInterval, and processes the replies that arrive on resp:
//
//   - a reply carrying flagCookie supplies the cookie the peer expects; the
//     request is resent immediately with that cookie,
//   - a reply carrying flagResponse completes the handshake; a started
//     *Conn is returned,
//   - any other reply is ignored.
//
// It gives up with ErrHandshakeTimeout once handshakeTimeout has elapsed,
// and with ErrClosedConn when the Mux is closed in the meantime.
func (m *Mux) clientHandshake(dst net.Addr, connID connectionID, resp chan packet) (*Conn, error) {
	if debugMux {
		log.Printf("%v dial %v connID %v", m, dst, connID)
	}

	// A zero duration makes the first request go out without delay.
	resend := time.NewTimer(0)
	defer resend.Stop()

	deadline := time.NewTimer(handshakeTimeout)
	defer deadline.Stop()

	var remoteCookie uint32
	seqNo := randomSeqNo()

	// sendRequest transmits one handshake request carrying whatever cookie
	// is known at the time of the call.
	sendRequest := func() {
		m.write(packet{
			src: connID,
			dst: dst,
			hdr: header{
				packetType: typeHandshake,
				flags:      flagRequest,
				connID:     0,
				sequenceNo: seqNo,
				timestamp:  timestampMicros(),
			},
			data: handshakeData{uint32(m.packetSize), connID, remoteCookie}.marshal(),
		})
	}

	for {
		select {
		case <-m.closed:
			// Failure. The mux has been closed.
			return nil, ErrClosedConn

		case <-deadline.C:
			// Handshake timeout. Close and abort.
			return nil, ErrHandshakeTimeout

		case <-resend.C:
			sendRequest()
			resend.Reset(handshakeInterval)

		case reply := <-resp:
			hd := unmarshalHandshakeData(reply.data)
			flags := reply.hdr.flags

			switch {
			case flags&flagCookie == flagCookie:
				// The peer wants the request repeated with this cookie.
				remoteCookie = hd.cookie
				resend.Reset(0)

			case flags&flagResponse == flagResponse:
				// Successful handshake response.
				size := int(hd.packetSize)
				if size > m.packetSize {
					// Never exceed what this Mux can handle itself.
					size = m.packetSize
				}

				conn := newConn(m, dst)
				conn.connID = connID
				conn.remoteConnID = hd.connID
				conn.nextRecvSeqNo = reply.hdr.sequenceNo + 1
				conn.packetSize = size
				conn.nextSeqNo = seqNo + 1
				conn.start()
				return conn, nil
			}
		}
	}
}
// readerLoop reads datagrams from the packet connection until a read
// fails, at which point it closes the Mux and returns.
//
// Handshake packets are handed to incomingHandshake. Every other packet is
// delivered to the connection registered under the header's connection ID;
// packets for unknown connections are dropped. Datagrams too short to hold
// a DST header are dropped before any parsing takes place.
func (m *Mux) readerLoop() {
	buf := make([]byte, m.packetSize)
	for {
		buf = buf[:cap(buf)]
		n, from, err := m.conn.ReadFrom(buf)
		if err != nil {
			// Nobody is left to report a close error to.
			m.Close()
			return
		}

		// Datagrams arrive from the network and are untrusted. One that is
		// shorter than the header cannot be valid, so it must never reach
		// the header parser (defined elsewhere; its handling of truncated
		// input is not visible from here).
		if n < dstHeaderLen {
			if debugMux {
				log.Printf("%v dropping short packet (%d bytes) from %v", m, n, from)
			}
			continue
		}
		buf = buf[:n]

		hdr := unmarshalHeader(buf)

		// The read buffer is reused on the next iteration, so the payload
		// is copied into a pooled buffer before being handed on.
		var payload []byte
		if len(buf) > dstHeaderLen {
			payload = m.buffers.Get().([]byte)[:len(buf)-dstHeaderLen]
			copy(payload, buf[dstHeaderLen:])
		}

		pkt := packet{hdr: hdr, data: payload}
		if debugMux {
			log.Println(m, "read", pkt)
		}

		if hdr.packetType == typeHandshake {
			m.incomingHandshake(from, hdr, payload)
			continue
		}

		m.connsMut.Lock()
		conn, ok := m.conns[hdr.connID]
		m.connsMut.Unlock()

		switch {
		case ok:
			conn.in <- pkt
		case debugMux && hdr.packetType != typeShutdown:
			log.Printf("packet %v for unknown conn %v", hdr, hdr.connID)
		}
	}
}
// incomingHandshake dispatches a received handshake packet. A connection
// ID of zero marks a request that opens a new connection; any other value
// is treated as a reply to a handshake started by this Mux.
func (m *Mux) incomingHandshake(from net.Addr, hdr header, data []byte) {
	isNewRequest := hdr.connID == 0
	if isNewRequest {
		m.incomingHandshakeRequest(from, hdr, data)
		return
	}
	m.incomingHandshakeResponse(from, hdr, data)
}
// incomingHandshakeRequest handles a handshake packet addressed to
// connection ID zero, i.e. a remote peer asking for a new connection.
//
// Packets without flagRequest are logged and ignored. A request whose
// cookie does not match the one computed for the sender is answered with
// that cookie and otherwise ignored. A request with the right cookie
// creates, starts and registers a new Conn, is answered with a handshake
// response, and the Conn is queued for Accept.
func (m *Mux) incomingHandshakeRequest(from net.Addr, hdr header, data []byte) {
	if hdr.flags&flagRequest != flagRequest {
		log.Printf("Handshake pattern with flags 0x%x to connID zero", hdr.flags)
		return
	}

	req := unmarshalHandshakeData(data)
	expected := cookie(from)

	if req.cookie != expected {
		// Incorrect or missing SYN cookie. Tell the sender which one to
		// use and wait for it to try again.
		challenge := packet{
			dst: from,
			hdr: header{
				packetType: typeHandshake,
				flags:      flagResponse | flagCookie,
				connID:     req.connID,
				timestamp:  timestampMicros(),
			},
			data: handshakeData{
				packetSize: uint32(m.packetSize),
				cookie:     expected,
			}.marshal(),
		}
		m.write(challenge)
		return
	}

	seqNo := randomSeqNo()

	m.connsMut.Lock()
	localID := m.newConnID()

	// Settle on the smaller of the two sides' packet sizes.
	size := int(req.packetSize)
	if size > m.packetSize {
		size = m.packetSize
	}

	conn := newConn(m, from)
	conn.connID = localID
	conn.remoteConnID = req.connID
	conn.nextSeqNo = seqNo + 1
	conn.nextRecvSeqNo = hdr.sequenceNo + 1
	conn.packetSize = size
	conn.start()

	m.conns[localID] = conn
	m.connsMut.Unlock()

	accepted := packet{
		dst: from,
		hdr: header{
			packetType: typeHandshake,
			flags:      flagResponse,
			connID:     req.connID,
			sequenceNo: seqNo,
			timestamp:  timestampMicros(),
		},
		data: handshakeData{
			connID:     conn.connID,
			packetSize: uint32(conn.packetSize),
		}.marshal(),
	}
	m.write(accepted)

	m.incoming <- conn
}
// incomingHandshakeResponse forwards a handshake reply to the dialer that
// is waiting for it. The dialer is found through the connection ID in the
// header; replies that match no handshake in progress are dropped.
func (m *Mux) incomingHandshakeResponse(from net.Addr, hdr header, data []byte) {
	m.connsMut.Lock()
	pending, found := m.handshakes[hdr.connID]
	m.connsMut.Unlock()

	if !found {
		if debugMux && hdr.packetType != typeShutdown {
			log.Printf("Handshake packet %v for unknown conn %v", hdr, hdr.connID)
		}
		return
	}

	// This is a response to a handshake in progress.
	reply := packet{hdr: hdr, data: data}
	pending <- reply
}
// write serializes pkt, header first and payload after it, into a pooled
// buffer and sends it to pkt.dst over the packet connection. It returns
// what the connection's WriteTo returned.
func (m *Mux) write(pkt packet) (int, error) {
	size := dstHeaderLen + len(pkt.data)
	frame := m.buffers.Get().([]byte)[:size]

	pkt.hdr.marshal(frame)
	copy(frame[dstHeaderLen:], pkt.data)

	if debugMux {
		log.Println(m, "write", pkt)
	}

	written, err := m.conn.WriteTo(frame, pkt.dst)

	// The frame is no longer needed once it has been handed to the socket.
	m.buffers.Put(frame)
	return written, err
}
// String identifies the Mux by its local address, in the form "Mux-<addr>".
func (m *Mux) String() string {
	local := m.Addr()
	return fmt.Sprintf("Mux-%v", local)
}
// newConnID draws random connection IDs until it finds one that is used
// neither by an established connection nor by a handshake in progress.
//
// It reads both maps without locking; the callers in this file hold
// connsMut while calling it.
func (m *Mux) newConnID() connectionID {
	for {
		candidate := randomConnID()
		_, established := m.conns[candidate]
		_, pending := m.handshakes[candidate]
		if !established && !pending {
			return candidate
		}
	}
}
// removeConn unregisters c so that packets carrying its connection ID are
// no longer routed to it.
func (m *Mux) removeConn(c *Conn) {
	m.connsMut.Lock()
	defer m.connsMut.Unlock()

	delete(m.conns, c.connID)
}