add core
This commit is contained in:
132
core/cs/client/client.go
Normal file
132
core/cs/client/client.go
Normal file
@ -0,0 +1,132 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/snail007/goproxy/core/lib/kcpcfg"
|
||||
compressconn "github.com/snail007/goproxy/core/lib/transport"
|
||||
encryptconn "github.com/snail007/goproxy/core/lib/transport/encrypt"
|
||||
"github.com/snail007/goproxy/core/dst"
|
||||
kcp "github.com/xtaci/kcp-go"
|
||||
)
|
||||
|
||||
func TlsConnectHost(host string, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
|
||||
h := strings.Split(host, ":")
|
||||
port, _ := strconv.Atoi(h[1])
|
||||
return TlsConnect(h[0], port, timeout, certBytes, keyBytes, caCertBytes)
|
||||
}
|
||||
|
||||
func TlsConnect(host string, port, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
|
||||
conf, err := getRequestTlsConfig(certBytes, keyBytes, caCertBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return *tls.Client(_conn, conf), err
|
||||
}
|
||||
func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) {
|
||||
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
serverCertPool := x509.NewCertPool()
|
||||
caBytes := certBytes
|
||||
if caCertBytes != nil {
|
||||
caBytes = caCertBytes
|
||||
|
||||
}
|
||||
ok := serverCertPool.AppendCertsFromPEM(caBytes)
|
||||
if !ok {
|
||||
err = errors.New("failed to parse root certificate")
|
||||
}
|
||||
block, _ := pem.Decode(caBytes)
|
||||
if block == nil {
|
||||
panic("failed to parse certificate PEM")
|
||||
}
|
||||
x509Cert, _ := x509.ParseCertificate(block.Bytes)
|
||||
if x509Cert == nil {
|
||||
panic("failed to parse block")
|
||||
}
|
||||
conf = &tls.Config{
|
||||
RootCAs: serverCertPool,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: x509Cert.Subject.CommonName,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: serverCertPool,
|
||||
}
|
||||
for _, rawCert := range rawCerts {
|
||||
cert, _ := x509.ParseCertificate(rawCert)
|
||||
_, err := cert.Verify(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TCPConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
|
||||
conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
|
||||
return
|
||||
}
|
||||
|
||||
func TCPSConnectHost(hostAndPort string, method, password string, compress bool, timeout int) (conn net.Conn, err error) {
|
||||
conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if compress {
|
||||
conn = compressconn.NewCompConn(conn)
|
||||
}
|
||||
conn, err = encryptconn.NewConn(conn, method, password)
|
||||
return
|
||||
}
|
||||
|
||||
func TOUConnectHost(hostAndPort string, method, password string, compress bool, timeout int) (conn net.Conn, err error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Create a DST mux around the packet connection with the default max
|
||||
// packet size.
|
||||
mux := dst.NewMux(udpConn, 0)
|
||||
conn, err = mux.Dial("dst", hostAndPort)
|
||||
if compress {
|
||||
conn = compressconn.NewCompConn(conn)
|
||||
}
|
||||
conn, err = encryptconn.NewConn(conn, method, password)
|
||||
return
|
||||
}
|
||||
func KCPConnectHost(hostAndPort string, config kcpcfg.KCPConfigArgs) (conn net.Conn, err error) {
|
||||
kcpconn, err := kcp.DialWithOptions(hostAndPort, config.Block, *config.DataShard, *config.ParityShard)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
kcpconn.SetStreamMode(true)
|
||||
kcpconn.SetWriteDelay(true)
|
||||
kcpconn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion)
|
||||
kcpconn.SetMtu(*config.MTU)
|
||||
kcpconn.SetWindowSize(*config.SndWnd, *config.RcvWnd)
|
||||
kcpconn.SetACKNoDelay(*config.AckNodelay)
|
||||
if *config.NoComp {
|
||||
return kcpconn, err
|
||||
}
|
||||
return compressconn.NewCompStream(kcpconn), err
|
||||
}
|
||||
342
core/cs/server/server.go
Normal file
342
core/cs/server/server.go
Normal file
@ -0,0 +1,342 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
logger "log"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
|
||||
compressconn "github.com/snail007/goproxy/core/lib/transport"
|
||||
transportc "github.com/snail007/goproxy/core/lib/transport"
|
||||
encryptconn "github.com/snail007/goproxy/core/lib/transport/encrypt"
|
||||
tou "github.com/snail007/goproxy/core/dst"
|
||||
|
||||
"github.com/snail007/goproxy/core/lib/kcpcfg"
|
||||
|
||||
kcp "github.com/xtaci/kcp-go"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
}
|
||||
|
||||
type ServerChannel struct {
|
||||
ip string
|
||||
port int
|
||||
Listener *net.Listener
|
||||
UDPListener *net.UDPConn
|
||||
errAcceptHandler func(err error)
|
||||
log *logger.Logger
|
||||
TOUServer *tou.Mux
|
||||
}
|
||||
|
||||
func NewServerChannel(ip string, port int, log *logger.Logger) ServerChannel {
|
||||
return ServerChannel{
|
||||
ip: ip,
|
||||
port: port,
|
||||
log: log,
|
||||
errAcceptHandler: func(err error) {
|
||||
log.Printf("accept error , ERR:%s", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
func NewServerChannelHost(host string, log *logger.Logger) ServerChannel {
|
||||
h, port, _ := net.SplitHostPort(host)
|
||||
p, _ := strconv.Atoi(port)
|
||||
return ServerChannel{
|
||||
ip: h,
|
||||
port: p,
|
||||
log: log,
|
||||
errAcceptHandler: func(err error) {
|
||||
log.Printf("accept error , ERR:%s", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
func (s *ServerChannel) SetErrAcceptHandler(fn func(err error)) {
|
||||
s.errAcceptHandler = fn
|
||||
}
|
||||
func (s *ServerChannel) ListenSingleTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) {
|
||||
return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, true)
|
||||
|
||||
}
|
||||
func (s *ServerChannel) ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) {
|
||||
return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, false)
|
||||
}
|
||||
func (s *ServerChannel) _ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn), single bool) (err error) {
|
||||
s.Listener, err = s.listenTLS(s.ip, s.port, certBytes, keyBytes, caCertBytes, single)
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("ListenTLS crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
var conn net.Conn
|
||||
conn, err = (*s.Listener).Accept()
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("tls connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn(conn)
|
||||
}()
|
||||
} else {
|
||||
s.errAcceptHandler(err)
|
||||
(*s.Listener).Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
func (s *ServerChannel) listenTLS(ip string, port int, certBytes, keyBytes, caCertBytes []byte, single bool) (ln *net.Listener, err error) {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
config := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
if !single {
|
||||
clientCertPool := x509.NewCertPool()
|
||||
caBytes := certBytes
|
||||
if caCertBytes != nil {
|
||||
caBytes = caCertBytes
|
||||
}
|
||||
ok := clientCertPool.AppendCertsFromPEM(caBytes)
|
||||
if !ok {
|
||||
err = errors.New("failed to parse root certificate")
|
||||
}
|
||||
config.ClientCAs = clientCertPool
|
||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
_ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config)
|
||||
if err == nil {
|
||||
ln = &_ln
|
||||
}
|
||||
return
|
||||
}
|
||||
func (s *ServerChannel) ListenTCPS(method, password string, compress bool, fn func(conn net.Conn)) (err error) {
|
||||
_, err = encryptconn.NewCipher(method, password)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.ListenTCP(func(c net.Conn) {
|
||||
if compress {
|
||||
c = transportc.NewCompConn(c)
|
||||
}
|
||||
c, _ = encryptconn.NewConn(c, method, password)
|
||||
fn(c)
|
||||
})
|
||||
}
|
||||
func (s *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) {
|
||||
var l net.Listener
|
||||
l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.ip, s.port))
|
||||
if err == nil {
|
||||
s.Listener = &l
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
var conn net.Conn
|
||||
conn, err = (*s.Listener).Accept()
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn(conn)
|
||||
}()
|
||||
} else {
|
||||
s.errAcceptHandler(err)
|
||||
(*s.Listener).Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
func (s *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) {
|
||||
addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port}
|
||||
l, err := net.ListenUDP("udp", addr)
|
||||
if err == nil {
|
||||
s.UDPListener = l
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
var buf = make([]byte, 2048)
|
||||
n, srcAddr, err := (*s.UDPListener).ReadFromUDP(buf)
|
||||
if err == nil {
|
||||
packet := buf[0:n]
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn(packet, addr, srcAddr)
|
||||
}()
|
||||
} else {
|
||||
s.errAcceptHandler(err)
|
||||
(*s.UDPListener).Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
func (s *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net.Conn), log *logger.Logger) (err error) {
|
||||
lis, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", s.ip, s.port), config.Block, *config.DataShard, *config.ParityShard)
|
||||
if err == nil {
|
||||
if err = lis.SetDSCP(*config.DSCP); err != nil {
|
||||
log.Println("SetDSCP:", err)
|
||||
return
|
||||
}
|
||||
if err = lis.SetReadBuffer(*config.SockBuf); err != nil {
|
||||
log.Println("SetReadBuffer:", err)
|
||||
return
|
||||
}
|
||||
if err = lis.SetWriteBuffer(*config.SockBuf); err != nil {
|
||||
log.Println("SetWriteBuffer:", err)
|
||||
return
|
||||
}
|
||||
s.Listener = new(net.Listener)
|
||||
*s.Listener = lis
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("ListenKCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
//var conn net.Conn
|
||||
conn, err := lis.AcceptKCP()
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("kcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWriteDelay(true)
|
||||
conn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion)
|
||||
conn.SetMtu(*config.MTU)
|
||||
conn.SetWindowSize(*config.SndWnd, *config.RcvWnd)
|
||||
conn.SetACKNoDelay(*config.AckNodelay)
|
||||
if *config.NoComp {
|
||||
fn(conn)
|
||||
} else {
|
||||
cconn := transportc.NewCompStream(conn)
|
||||
fn(cconn)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
s.errAcceptHandler(err)
|
||||
(*s.Listener).Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *ServerChannel) ListenTOU(method, password string, compress bool, fn func(conn net.Conn)) (err error) {
|
||||
addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port}
|
||||
s.UDPListener, err = net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
s.log.Println(err)
|
||||
return
|
||||
}
|
||||
s.TOUServer = tou.NewMux(s.UDPListener, 0)
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("ListenRUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
var conn net.Conn
|
||||
conn, err = (*s.TOUServer).Accept()
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
if compress {
|
||||
conn = compressconn.NewCompConn(conn)
|
||||
}
|
||||
conn, err = encryptconn.NewConn(conn, method, password)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
s.log.Println(err)
|
||||
return
|
||||
}
|
||||
fn(conn)
|
||||
}()
|
||||
} else {
|
||||
s.errAcceptHandler(err)
|
||||
s.TOUServer.Close()
|
||||
s.UDPListener.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
func (s *ServerChannel) Close() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
if s.Listener != nil && *s.Listener != nil {
|
||||
(*s.Listener).Close()
|
||||
}
|
||||
if s.TOUServer != nil {
|
||||
s.TOUServer.Close()
|
||||
}
|
||||
if s.UDPListener != nil {
|
||||
s.UDPListener.Close()
|
||||
}
|
||||
}
|
||||
func (s *ServerChannel) Addr() string {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
if s.Listener != nil && *s.Listener != nil {
|
||||
return (*s.Listener).Addr().String()
|
||||
}
|
||||
|
||||
if s.UDPListener != nil {
|
||||
return s.UDPListener.LocalAddr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
49
core/cs/tests/transport_test.go
Normal file
49
core/cs/tests/transport_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
ctransport "github.com/snail007/goproxy/core/cs/client"
|
||||
stransport "github.com/snail007/goproxy/core/cs/server"
|
||||
)
|
||||
|
||||
func TestTCPS(t *testing.T) {
|
||||
l := log.New(os.Stderr, "", log.LstdFlags)
|
||||
s := stransport.NewServerChannelHost(":", l)
|
||||
err := s.ListenTCPS("aes-256-cfb", "password", true, func(inconn net.Conn) {
|
||||
buf := make([]byte, 2048)
|
||||
_, err := inconn.Read(buf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
_, err = inconn.Write([]byte("okay"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := ctransport.TCPSConnectHost((*s.Listener).Addr().String(), "aes-256-cfb", "password", true, 1000)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
_, err = client.Write([]byte("test"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b := make([]byte, 20)
|
||||
n, err := client.Read(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b[:n]) != "okay" {
|
||||
t.Fatalf("client revecive okay excepted,revecived : %s", string(b[:n]))
|
||||
}
|
||||
}
|
||||
586
core/dst/conn.go
Normal file
586
core/dst/conn.go
Normal file
@ -0,0 +1,586 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defExpTime = 100 * time.Millisecond // N * (4 * RTT + RTTVar + SYN)
|
||||
expCountClose = 8 // close connection after this many Exps
|
||||
minTimeClose = 5 * time.Second // if at least this long has passed
|
||||
maxInputBuffer = 8 << 20 // bytes
|
||||
muxBufferPackets = 128 // buffer size of channel between mux and reader routine
|
||||
rttMeasureWindow = 32 // number of packets to track for RTT averaging
|
||||
rttMeasureSample = 128 // Sample every ... packet for RTT
|
||||
|
||||
// number of bytes to subtract from MTU when chunking data, to try to
|
||||
// avoid fragmentation
|
||||
sliceOverhead = 8 /*pppoe, similar*/ + 20 /*ipv4*/ + 8 /*udp*/ + 16 /*dst*/
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Properly seed the random number generator that we use for sequence
|
||||
// numbers and stuff.
|
||||
buf := make([]byte, 8)
|
||||
if n, err := crand.Read(buf); n != 8 || err != nil {
|
||||
panic("init random failure")
|
||||
}
|
||||
rand.Seed(int64(binary.BigEndian.Uint64(buf)))
|
||||
}
|
||||
|
||||
// TODO: export this interface when it's usable from the outside
|
||||
type congestionController interface {
|
||||
Ack()
|
||||
NegAck()
|
||||
Exp()
|
||||
SendWindow() int
|
||||
PacketRate() int // PPS
|
||||
UpdateRTT(time.Duration)
|
||||
}
|
||||
|
||||
// Conn is an SDT connection carried over a Mux.
|
||||
type Conn struct {
|
||||
// Set at creation, thereafter immutable:
|
||||
|
||||
mux *Mux
|
||||
dst net.Addr
|
||||
connID connectionID
|
||||
remoteConnID connectionID
|
||||
in chan packet
|
||||
cc congestionController
|
||||
packetSize int
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
|
||||
// Touched by more than one goroutine, needs locking.
|
||||
|
||||
nextSeqNoMut sync.Mutex
|
||||
nextSeqNo sequenceNo
|
||||
|
||||
inbufMut sync.Mutex
|
||||
inbufCond *sync.Cond
|
||||
inbuf bytes.Buffer
|
||||
|
||||
expMut sync.Mutex
|
||||
exp *time.Timer
|
||||
|
||||
sendBuffer *sendBuffer // goroutine safe
|
||||
|
||||
packetDelays [rttMeasureWindow]time.Duration
|
||||
packetDelaysSlot int
|
||||
packetDelaysMut sync.Mutex
|
||||
|
||||
// Owned by the reader routine, needs no locking
|
||||
|
||||
recvBuffer packetList
|
||||
nextRecvSeqNo sequenceNo
|
||||
lastAckedSeqNo sequenceNo
|
||||
lastNegAckedSeqNo sequenceNo
|
||||
expCount int
|
||||
expReset time.Time
|
||||
|
||||
// Only accessed atomically
|
||||
|
||||
packetsIn int64
|
||||
packetsOut int64
|
||||
bytesIn int64
|
||||
bytesOut int64
|
||||
resentPackets int64
|
||||
droppedPackets int64
|
||||
outOfOrderPackets int64
|
||||
|
||||
// Special
|
||||
|
||||
debugResetRecvSeqNo chan sequenceNo
|
||||
}
|
||||
|
||||
func newConn(m *Mux, dst net.Addr) *Conn {
|
||||
conn := &Conn{
|
||||
mux: m,
|
||||
dst: dst,
|
||||
nextSeqNo: sequenceNo(rand.Uint32()),
|
||||
packetSize: maxPacketSize,
|
||||
in: make(chan packet, muxBufferPackets),
|
||||
closed: make(chan struct{}),
|
||||
sendBuffer: newSendBuffer(m),
|
||||
exp: time.NewTimer(defExpTime),
|
||||
debugResetRecvSeqNo: make(chan sequenceNo),
|
||||
expReset: time.Now(),
|
||||
}
|
||||
|
||||
conn.lastAckedSeqNo = conn.nextSeqNo - 1
|
||||
conn.inbufCond = sync.NewCond(&conn.inbufMut)
|
||||
|
||||
conn.cc = newWindowCC()
|
||||
conn.sendBuffer.SetWindowAndRate(conn.cc.SendWindow(), conn.cc.PacketRate())
|
||||
conn.recvBuffer.Resize(128)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (c *Conn) start() {
|
||||
go c.reader()
|
||||
}
|
||||
|
||||
func (c *Conn) reader() {
|
||||
if debugConnection {
|
||||
log.Println(c, "reader() starting")
|
||||
defer log.Println(c, "reader() exiting")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.closed:
|
||||
// Ack any received but not yet acked messages.
|
||||
c.sendAck(0)
|
||||
|
||||
// Send a shutdown message.
|
||||
c.nextSeqNoMut.Lock()
|
||||
c.mux.write(packet{
|
||||
src: c.connID,
|
||||
dst: c.dst,
|
||||
hdr: header{
|
||||
packetType: typeShutdown,
|
||||
connID: c.remoteConnID,
|
||||
sequenceNo: c.nextSeqNo,
|
||||
},
|
||||
})
|
||||
c.nextSeqNo++
|
||||
c.nextSeqNoMut.Unlock()
|
||||
atomic.AddInt64(&c.packetsOut, 1)
|
||||
atomic.AddInt64(&c.bytesOut, dstHeaderLen)
|
||||
return
|
||||
|
||||
case pkt := <-c.in:
|
||||
atomic.AddInt64(&c.packetsIn, 1)
|
||||
atomic.AddInt64(&c.bytesIn, dstHeaderLen+int64(len(pkt.data)))
|
||||
|
||||
c.expCount = 1
|
||||
|
||||
switch pkt.hdr.packetType {
|
||||
case typeData:
|
||||
c.rcvData(pkt)
|
||||
case typeAck:
|
||||
c.rcvAck(pkt)
|
||||
case typeNegAck:
|
||||
c.rcvNegAck(pkt)
|
||||
case typeShutdown:
|
||||
c.rcvShutdown(pkt)
|
||||
default:
|
||||
log.Println("Unhandled packet", pkt)
|
||||
continue
|
||||
}
|
||||
|
||||
case <-c.exp.C:
|
||||
c.eventExp()
|
||||
c.resetExp()
|
||||
|
||||
case n := <-c.debugResetRecvSeqNo:
|
||||
// Back door for testing
|
||||
c.lastAckedSeqNo = n - 1
|
||||
c.nextRecvSeqNo = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) eventExp() {
|
||||
c.expCount++
|
||||
|
||||
if c.sendBuffer.lost.Len() > 0 || c.sendBuffer.send.Len() > 0 {
|
||||
c.cc.Exp()
|
||||
c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate())
|
||||
c.sendBuffer.ScheduleResend()
|
||||
|
||||
if debugConnection {
|
||||
log.Println(c, "did resends due to Exp")
|
||||
}
|
||||
|
||||
if c.expCount > expCountClose && time.Since(c.expReset) > minTimeClose {
|
||||
if debugConnection {
|
||||
log.Println(c, "close due to Exp")
|
||||
}
|
||||
|
||||
// We're shutting down due to repeated exp:s. Don't wait for the
|
||||
// send buffer to drain, which it would otherwise do in
|
||||
// c.Close()..
|
||||
c.sendBuffer.CrashStop()
|
||||
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rcvAck(pkt packet) {
|
||||
ack := pkt.hdr.sequenceNo
|
||||
|
||||
if debugConnection {
|
||||
log.Printf("%v read Ack %v", c, ack)
|
||||
}
|
||||
|
||||
c.cc.Ack()
|
||||
|
||||
if ack%rttMeasureSample == 0 {
|
||||
if ts := timestamp(binary.BigEndian.Uint32(pkt.data)); ts > 0 {
|
||||
if delay := time.Duration(timestampMicros()-ts) * time.Microsecond; delay > 0 {
|
||||
c.packetDelaysMut.Lock()
|
||||
c.packetDelays[c.packetDelaysSlot] = delay
|
||||
c.packetDelaysSlot = (c.packetDelaysSlot + 1) % len(c.packetDelays)
|
||||
c.packetDelaysMut.Unlock()
|
||||
|
||||
if rtt, n := c.averageDelay(); n > 8 {
|
||||
c.cc.UpdateRTT(rtt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.sendBuffer.Acknowledge(ack)
|
||||
c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate())
|
||||
|
||||
c.resetExp()
|
||||
}
|
||||
|
||||
func (c *Conn) averageDelay() (time.Duration, int) {
|
||||
var total time.Duration
|
||||
var n int
|
||||
|
||||
c.packetDelaysMut.Lock()
|
||||
for _, d := range c.packetDelays {
|
||||
if d != 0 {
|
||||
total += d
|
||||
n++
|
||||
}
|
||||
}
|
||||
c.packetDelaysMut.Unlock()
|
||||
|
||||
if n == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
return total / time.Duration(n), n
|
||||
}
|
||||
|
||||
func (c *Conn) rcvNegAck(pkt packet) {
|
||||
nak := pkt.hdr.sequenceNo
|
||||
|
||||
if debugConnection {
|
||||
log.Printf("%v read NegAck %v", c, nak)
|
||||
}
|
||||
|
||||
c.sendBuffer.NegativeAck(nak)
|
||||
|
||||
//c.cc.NegAck()
|
||||
c.resetExp()
|
||||
}
|
||||
|
||||
func (c *Conn) rcvShutdown(pkt packet) {
|
||||
// XXX: We accept shutdown packets somewhat from the future since the
|
||||
// sender will number the shutdown after any packets that might still be
|
||||
// in the write buffer. This should be fixed to let the write buffer empty
|
||||
// on close and reduce the window here.
|
||||
if pkt.LessSeq(c.nextRecvSeqNo + 128) {
|
||||
if debugConnection {
|
||||
log.Println(c, "close due to shutdown")
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) rcvData(pkt packet) {
|
||||
if debugConnection {
|
||||
log.Println(c, "recv data", pkt.hdr)
|
||||
}
|
||||
|
||||
if pkt.LessSeq(c.nextRecvSeqNo) {
|
||||
if debugConnection {
|
||||
log.Printf("%v old packet received; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo)
|
||||
}
|
||||
atomic.AddInt64(&c.droppedPackets, 1)
|
||||
return
|
||||
}
|
||||
|
||||
if debugConnection {
|
||||
log.Println(c, "into recv buffer:", pkt)
|
||||
}
|
||||
c.recvBuffer.InsertSorted(pkt)
|
||||
if c.recvBuffer.LowestSeq() == c.nextRecvSeqNo {
|
||||
for _, pkt := range c.recvBuffer.PopSequence(^sequenceNo(0)) {
|
||||
if debugConnection {
|
||||
log.Println(c, "from recv buffer:", pkt)
|
||||
}
|
||||
|
||||
// An in-sequence packet.
|
||||
|
||||
c.nextRecvSeqNo = pkt.hdr.sequenceNo + 1
|
||||
|
||||
c.sendAck(pkt.hdr.timestamp)
|
||||
|
||||
c.inbufMut.Lock()
|
||||
for c.inbuf.Len() > len(pkt.data)+maxInputBuffer {
|
||||
c.inbufCond.Wait()
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
c.inbuf.Write(pkt.data)
|
||||
c.inbufCond.Broadcast()
|
||||
c.inbufMut.Unlock()
|
||||
}
|
||||
} else {
|
||||
if debugConnection {
|
||||
log.Printf("%v lost; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo)
|
||||
}
|
||||
c.recvBuffer.InsertSorted(pkt)
|
||||
c.sendNegAck()
|
||||
atomic.AddInt64(&c.outOfOrderPackets, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendAck(ts timestamp) {
|
||||
if c.lastAckedSeqNo == c.nextRecvSeqNo {
|
||||
return
|
||||
}
|
||||
|
||||
var buf [4]byte
|
||||
binary.BigEndian.PutUint32(buf[:], uint32(ts))
|
||||
c.mux.write(packet{
|
||||
src: c.connID,
|
||||
dst: c.dst,
|
||||
hdr: header{
|
||||
packetType: typeAck,
|
||||
connID: c.remoteConnID,
|
||||
sequenceNo: c.nextRecvSeqNo,
|
||||
},
|
||||
data: buf[:],
|
||||
})
|
||||
|
||||
atomic.AddInt64(&c.packetsOut, 1)
|
||||
atomic.AddInt64(&c.bytesOut, dstHeaderLen)
|
||||
if debugConnection {
|
||||
log.Printf("%v send Ack %v", c, c.nextRecvSeqNo)
|
||||
}
|
||||
|
||||
c.lastAckedSeqNo = c.nextRecvSeqNo
|
||||
}
|
||||
|
||||
func (c *Conn) sendNegAck() {
|
||||
if c.lastNegAckedSeqNo == c.nextRecvSeqNo {
|
||||
return
|
||||
}
|
||||
|
||||
c.mux.write(packet{
|
||||
src: c.connID,
|
||||
dst: c.dst,
|
||||
hdr: header{
|
||||
packetType: typeNegAck,
|
||||
connID: c.remoteConnID,
|
||||
sequenceNo: c.nextRecvSeqNo,
|
||||
},
|
||||
})
|
||||
|
||||
atomic.AddInt64(&c.packetsOut, 1)
|
||||
atomic.AddInt64(&c.bytesOut, dstHeaderLen)
|
||||
if debugConnection {
|
||||
log.Printf("%v send NegAck %v", c, c.nextRecvSeqNo)
|
||||
}
|
||||
|
||||
c.lastNegAckedSeqNo = c.nextRecvSeqNo
|
||||
}
|
||||
|
||||
func (c *Conn) resetExp() {
|
||||
d, _ := c.averageDelay()
|
||||
d = d*4 + 10*time.Millisecond
|
||||
|
||||
if d < defExpTime {
|
||||
d = defExpTime
|
||||
}
|
||||
|
||||
c.expMut.Lock()
|
||||
c.exp.Reset(d)
|
||||
c.expMut.Unlock()
|
||||
}
|
||||
|
||||
// String returns a string representation of the connection.
|
||||
func (c *Conn) String() string {
|
||||
return fmt.Sprintf("%v/%v/%v", c.connID, c.LocalAddr(), c.RemoteAddr())
|
||||
}
|
||||
|
||||
// Read reads data from the connection.
|
||||
// Read can be made to time out and return a Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
c.inbufMut.Lock()
|
||||
defer c.inbufMut.Unlock()
|
||||
for c.inbuf.Len() == 0 {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, io.EOF
|
||||
default:
|
||||
}
|
||||
c.inbufCond.Wait()
|
||||
}
|
||||
return c.inbuf.Read(b)
|
||||
}
|
||||
|
||||
// Write writes data to the connection.
|
||||
// Write can be made to time out and return a Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, ErrClosedConn
|
||||
default:
|
||||
}
|
||||
|
||||
sent := 0
|
||||
sliceSize := c.packetSize - sliceOverhead
|
||||
for i := 0; i < len(b); i += sliceSize {
|
||||
nxt := i + sliceSize
|
||||
if nxt > len(b) {
|
||||
nxt = len(b)
|
||||
}
|
||||
slice := b[i:nxt]
|
||||
sliceCopy := c.mux.buffers.Get().([]byte)[:len(slice)]
|
||||
copy(sliceCopy, slice)
|
||||
|
||||
c.nextSeqNoMut.Lock()
|
||||
pkt := packet{
|
||||
src: c.connID,
|
||||
dst: c.dst,
|
||||
hdr: header{
|
||||
packetType: typeData,
|
||||
sequenceNo: c.nextSeqNo,
|
||||
connID: c.remoteConnID,
|
||||
},
|
||||
data: sliceCopy,
|
||||
}
|
||||
c.nextSeqNo++
|
||||
c.nextSeqNoMut.Unlock()
|
||||
|
||||
if err := c.sendBuffer.Write(pkt); err != nil {
|
||||
return sent, err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&c.packetsOut, 1)
|
||||
atomic.AddInt64(&c.bytesOut, int64(len(slice)+dstHeaderLen))
|
||||
|
||||
sent += len(slice)
|
||||
c.resetExp()
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
// Any blocked Read or Write operations will be unblocked and return errors.
|
||||
func (c *Conn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
if debugConnection {
|
||||
log.Println(c, "explicit close start")
|
||||
defer log.Println(c, "explicit close done")
|
||||
}
|
||||
|
||||
// XXX: Ugly hack to implement lingering sockets...
|
||||
time.Sleep(4 * defExpTime)
|
||||
|
||||
c.sendBuffer.Stop()
|
||||
c.mux.removeConn(c)
|
||||
close(c.closed)
|
||||
|
||||
c.inbufMut.Lock()
|
||||
c.inbufCond.Broadcast()
|
||||
c.inbufMut.Unlock()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.mux.Addr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.dst
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines associated
|
||||
// with the connection. It is equivalent to calling both
|
||||
// SetReadDeadline and SetWriteDeadline.
|
||||
//
|
||||
// A deadline is an absolute time after which I/O operations
|
||||
// fail with a timeout (see type Error) instead of
|
||||
// blocking. The deadline applies to all future I/O, not just
|
||||
// the immediately following call to Read or Write.
|
||||
//
|
||||
// An idle timeout can be implemented by repeatedly extending
|
||||
// the deadline after successful Read or Write calls.
|
||||
//
|
||||
// A zero value for t means I/O operations will not time out.
|
||||
//
|
||||
// BUG(jb): SetDeadline is not implemented.
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the deadline for future Read calls.
|
||||
// A zero value for t means Read will not time out.
|
||||
//
|
||||
// BUG(jb): SetReadDeadline is not implemented.
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the deadline for future Write calls.
|
||||
// Even if write times out, it may return n > 0, indicating that
|
||||
// some of the data was successfully written.
|
||||
// A zero value for t means Write will not time out.
|
||||
//
|
||||
// BUG(jb): SetWriteDeadline is not implemented.
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
type Statistics struct {
|
||||
DataPacketsIn int64
|
||||
DataPacketsOut int64
|
||||
DataBytesIn int64
|
||||
DataBytesOut int64
|
||||
ResentPackets int64
|
||||
DroppedPackets int64
|
||||
OutOfOrderPackets int64
|
||||
}
|
||||
|
||||
// String returns a printable represetnation of the Statistics.
|
||||
func (s Statistics) String() string {
|
||||
return fmt.Sprintf("PktsIn: %d, PktsOut: %d, BytesIn: %d, BytesOut: %d, PktsResent: %d, PktsDropped: %d, PktsOutOfOrder: %d",
|
||||
s.DataPacketsIn, s.DataPacketsOut, s.DataBytesIn, s.DataBytesOut, s.ResentPackets, s.DroppedPackets, s.OutOfOrderPackets)
|
||||
}
|
||||
|
||||
// GetStatistics returns a snapsht of the current connection statistics.
|
||||
func (c *Conn) GetStatistics() Statistics {
|
||||
return Statistics{
|
||||
DataPacketsIn: atomic.LoadInt64(&c.packetsIn),
|
||||
DataPacketsOut: atomic.LoadInt64(&c.packetsOut),
|
||||
DataBytesIn: atomic.LoadInt64(&c.bytesIn),
|
||||
DataBytesOut: atomic.LoadInt64(&c.bytesOut),
|
||||
ResentPackets: atomic.LoadInt64(&c.resentPackets),
|
||||
DroppedPackets: atomic.LoadInt64(&c.droppedPackets),
|
||||
OutOfOrderPackets: atomic.LoadInt64(&c.outOfOrderPackets),
|
||||
}
|
||||
}
|
||||
29
core/dst/cookie.go
Normal file
29
core/dst/cookie.go
Normal file
@ -0,0 +1,29 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
)
|
||||
|
||||
var cookieKey = make([]byte, 16)
|
||||
|
||||
func init() {
|
||||
_, err := rand.Reader.Read(cookieKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func cookie(remote net.Addr) uint32 {
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(remote.String()))
|
||||
hash.Write(cookieKey)
|
||||
bs := hash.Sum(nil)
|
||||
return binary.BigEndian.Uint32(bs)
|
||||
}
|
||||
26
core/dst/debug.go
Normal file
26
core/dst/debug.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
debugConnection bool
|
||||
debugMux bool
|
||||
debugCC bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
debug := make(map[string]bool)
|
||||
for _, s := range strings.Split(os.Getenv("DSTDEBUG"), ",") {
|
||||
debug[strings.TrimSpace(s)] = true
|
||||
}
|
||||
debugConnection = debug["conn"]
|
||||
debugMux = debug["mux"]
|
||||
debugCC = debug["cc"]
|
||||
}
|
||||
12
core/dst/doc.go
Normal file
12
core/dst/doc.go
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
|
||||
Package dst implements the Datagram Stream Transfer protocol.
|
||||
|
||||
DST is a way to get reliable stream connections (like TCP) on top of UDP.
|
||||
|
||||
*/
|
||||
package dst
|
||||
23
core/dst/errors.go
Normal file
23
core/dst/errors.go
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
// Error represents the various dst-internal error conditions.
|
||||
type Error struct {
|
||||
Err string
|
||||
}
|
||||
|
||||
// Error returns a string representation of the error.
|
||||
func (e Error) Error() string {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
var (
|
||||
ErrClosedConn = &Error{"operation on closed connection"}
|
||||
ErrClosedMux = &Error{"operation on closed mux"}
|
||||
ErrHandshakeTimeout = &Error{"handshake timeout"}
|
||||
ErrNotDST = &Error{"network is not dst"}
|
||||
ErrNotImplemented = &Error{"not implemented"}
|
||||
)
|
||||
422
core/dst/mux.go
Normal file
422
core/dst/mux.go
Normal file
@ -0,0 +1,422 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
maxIncomingRequests = 1024
|
||||
maxPacketSize = 500
|
||||
handshakeTimeout = 5 * time.Second
|
||||
handshakeInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// Mux is a UDP multiplexer of DST connections.
|
||||
type Mux struct {
|
||||
conn net.PacketConn
|
||||
packetSize int
|
||||
|
||||
conns map[connectionID]*Conn
|
||||
handshakes map[connectionID]chan packet
|
||||
connsMut sync.Mutex
|
||||
|
||||
incoming chan *Conn
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
|
||||
buffers *sync.Pool
|
||||
}
|
||||
|
||||
// NewMux creates a new DST Mux on top of a packet connection.
|
||||
func NewMux(conn net.PacketConn, packetSize int) *Mux {
|
||||
if packetSize <= 0 {
|
||||
packetSize = maxPacketSize
|
||||
}
|
||||
m := &Mux{
|
||||
conn: conn,
|
||||
packetSize: packetSize,
|
||||
conns: map[connectionID]*Conn{},
|
||||
handshakes: make(map[connectionID]chan packet),
|
||||
incoming: make(chan *Conn, maxIncomingRequests),
|
||||
closed: make(chan struct{}),
|
||||
buffers: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, packetSize)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Attempt to maximize buffer space. Start at 16 MB and work downwards 0.5
|
||||
// MB at a time.
|
||||
|
||||
if conn, ok := conn.(*net.UDPConn); ok {
|
||||
for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 {
|
||||
err := conn.SetReadBuffer(buf)
|
||||
if err == nil {
|
||||
if debugMux {
|
||||
log.Println(m, "read buffer is", buf)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 {
|
||||
err := conn.SetWriteBuffer(buf)
|
||||
if err == nil {
|
||||
if debugMux {
|
||||
log.Println(m, "write buffer is", buf)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go m.readerLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (m *Mux) Accept() (net.Conn, error) {
|
||||
return m.AcceptDST()
|
||||
}
|
||||
|
||||
// AcceptDST waits for and returns the next connection to the listener.
|
||||
func (m *Mux) AcceptDST() (*Conn, error) {
|
||||
conn, ok := <-m.incoming
|
||||
if !ok {
|
||||
return nil, ErrClosedMux
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
func (m *Mux) Close() error {
|
||||
var err error = ErrClosedMux
|
||||
m.closeOnce.Do(func() {
|
||||
err = m.conn.Close()
|
||||
close(m.incoming)
|
||||
close(m.closed)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (m *Mux) Addr() net.Addr {
|
||||
return m.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// Dial connects to the address on the named network.
|
||||
//
|
||||
// Network must be "dst".
|
||||
//
|
||||
// Addresses have the form host:port. If host is a literal IPv6 address or
|
||||
// host name, it must be enclosed in square brackets as in "[::1]:80",
|
||||
// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
|
||||
// SplitHostPort manipulate addresses in this form.
|
||||
//
|
||||
// Examples:
|
||||
// Dial("dst", "12.34.56.78:80")
|
||||
// Dial("dst", "google.com:http")
|
||||
// Dial("dst", "[2001:db8::1]:http")
|
||||
// Dial("dst", "[fe80::1%lo0]:80")
|
||||
func (m *Mux) Dial(network, addr string) (net.Conn, error) {
|
||||
return m.DialDST(network, addr)
|
||||
}
|
||||
|
||||
// Dial connects to the address on the named network.
|
||||
//
|
||||
// Network must be "dst".
|
||||
//
|
||||
// Addresses have the form host:port. If host is a literal IPv6 address or
|
||||
// host name, it must be enclosed in square brackets as in "[::1]:80",
|
||||
// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
|
||||
// SplitHostPort manipulate addresses in this form.
|
||||
//
|
||||
// Examples:
|
||||
// Dial("dst", "12.34.56.78:80")
|
||||
// Dial("dst", "google.com:http")
|
||||
// Dial("dst", "[2001:db8::1]:http")
|
||||
// Dial("dst", "[fe80::1%lo0]:80")
|
||||
func (m *Mux) DialDST(network, addr string) (*Conn, error) {
|
||||
if network != "dst" {
|
||||
return nil, ErrNotDST
|
||||
}
|
||||
|
||||
dst, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := make(chan packet)
|
||||
|
||||
m.connsMut.Lock()
|
||||
connID := m.newConnID()
|
||||
m.handshakes[connID] = resp
|
||||
m.connsMut.Unlock()
|
||||
|
||||
conn, err := m.clientHandshake(dst, connID, resp)
|
||||
|
||||
m.connsMut.Lock()
|
||||
defer m.connsMut.Unlock()
|
||||
delete(m.handshakes, connID)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.conns[connID] = conn
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handshake performs the client side handshake (i.e. Dial)
|
||||
func (m *Mux) clientHandshake(dst net.Addr, connID connectionID, resp chan packet) (*Conn, error) {
|
||||
if debugMux {
|
||||
log.Printf("%v dial %v connID %v", m, dst, connID)
|
||||
}
|
||||
|
||||
nextHandshake := time.NewTimer(0)
|
||||
defer nextHandshake.Stop()
|
||||
|
||||
handshakeTimeout := time.NewTimer(handshakeTimeout)
|
||||
defer handshakeTimeout.Stop()
|
||||
|
||||
var remoteCookie uint32
|
||||
seqNo := randomSeqNo()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.closed:
|
||||
// Failure. The mux has been closed.
|
||||
return nil, ErrClosedConn
|
||||
|
||||
case <-handshakeTimeout.C:
|
||||
// Handshake timeout. Close and abort.
|
||||
return nil, ErrHandshakeTimeout
|
||||
|
||||
case <-nextHandshake.C:
|
||||
// Send a handshake request.
|
||||
|
||||
m.write(packet{
|
||||
src: connID,
|
||||
dst: dst,
|
||||
hdr: header{
|
||||
packetType: typeHandshake,
|
||||
flags: flagRequest,
|
||||
connID: 0,
|
||||
sequenceNo: seqNo,
|
||||
timestamp: timestampMicros(),
|
||||
},
|
||||
data: handshakeData{uint32(m.packetSize), connID, remoteCookie}.marshal(),
|
||||
})
|
||||
nextHandshake.Reset(handshakeInterval)
|
||||
|
||||
case pkt := <-resp:
|
||||
hd := unmarshalHandshakeData(pkt.data)
|
||||
|
||||
if pkt.hdr.flags&flagCookie == flagCookie {
|
||||
// We should resend the handshake request with a different cookie value.
|
||||
remoteCookie = hd.cookie
|
||||
nextHandshake.Reset(0)
|
||||
} else if pkt.hdr.flags&flagResponse == flagResponse {
|
||||
// Successfull handshake response.
|
||||
conn := newConn(m, dst)
|
||||
|
||||
conn.connID = connID
|
||||
conn.remoteConnID = hd.connID
|
||||
conn.nextRecvSeqNo = pkt.hdr.sequenceNo + 1
|
||||
conn.packetSize = int(hd.packetSize)
|
||||
if conn.packetSize > m.packetSize {
|
||||
conn.packetSize = m.packetSize
|
||||
}
|
||||
|
||||
conn.nextSeqNo = seqNo + 1
|
||||
|
||||
conn.start()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) readerLoop() {
|
||||
buf := make([]byte, m.packetSize)
|
||||
for {
|
||||
buf = buf[:cap(buf)]
|
||||
n, from, err := m.conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
m.Close()
|
||||
return
|
||||
}
|
||||
buf = buf[:n]
|
||||
|
||||
hdr := unmarshalHeader(buf)
|
||||
|
||||
var bufCopy []byte
|
||||
if len(buf) > dstHeaderLen {
|
||||
bufCopy = m.buffers.Get().([]byte)[:len(buf)-dstHeaderLen]
|
||||
copy(bufCopy, buf[dstHeaderLen:])
|
||||
}
|
||||
|
||||
pkt := packet{hdr: hdr, data: bufCopy}
|
||||
if debugMux {
|
||||
log.Println(m, "read", pkt)
|
||||
}
|
||||
|
||||
if hdr.packetType == typeHandshake {
|
||||
m.incomingHandshake(from, hdr, bufCopy)
|
||||
} else {
|
||||
m.connsMut.Lock()
|
||||
conn, ok := m.conns[hdr.connID]
|
||||
m.connsMut.Unlock()
|
||||
|
||||
if ok {
|
||||
conn.in <- packet{
|
||||
dst: nil,
|
||||
hdr: hdr,
|
||||
data: bufCopy,
|
||||
}
|
||||
} else if debugMux && hdr.packetType != typeShutdown {
|
||||
log.Printf("packet %v for unknown conn %v", hdr, hdr.connID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) incomingHandshake(from net.Addr, hdr header, data []byte) {
|
||||
if hdr.connID == 0 {
|
||||
// A new incoming handshake request.
|
||||
m.incomingHandshakeRequest(from, hdr, data)
|
||||
} else {
|
||||
// A response to an ongoing handshake.
|
||||
m.incomingHandshakeResponse(from, hdr, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) incomingHandshakeRequest(from net.Addr, hdr header, data []byte) {
|
||||
if hdr.flags&flagRequest != flagRequest {
|
||||
log.Printf("Handshake pattern with flags 0x%x to connID zero", hdr.flags)
|
||||
return
|
||||
}
|
||||
|
||||
hd := unmarshalHandshakeData(data)
|
||||
|
||||
correctCookie := cookie(from)
|
||||
if hd.cookie != correctCookie {
|
||||
// Incorrect or missing SYN cookie. Send back a handshake
|
||||
// with the expected one.
|
||||
m.write(packet{
|
||||
dst: from,
|
||||
hdr: header{
|
||||
packetType: typeHandshake,
|
||||
flags: flagResponse | flagCookie,
|
||||
connID: hd.connID,
|
||||
timestamp: timestampMicros(),
|
||||
},
|
||||
data: handshakeData{
|
||||
packetSize: uint32(m.packetSize),
|
||||
cookie: correctCookie,
|
||||
}.marshal(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
seqNo := randomSeqNo()
|
||||
|
||||
m.connsMut.Lock()
|
||||
connID := m.newConnID()
|
||||
|
||||
conn := newConn(m, from)
|
||||
conn.connID = connID
|
||||
conn.remoteConnID = hd.connID
|
||||
conn.nextSeqNo = seqNo + 1
|
||||
conn.nextRecvSeqNo = hdr.sequenceNo + 1
|
||||
conn.packetSize = int(hd.packetSize)
|
||||
if conn.packetSize > m.packetSize {
|
||||
conn.packetSize = m.packetSize
|
||||
}
|
||||
conn.start()
|
||||
|
||||
m.conns[connID] = conn
|
||||
m.connsMut.Unlock()
|
||||
|
||||
m.write(packet{
|
||||
dst: from,
|
||||
hdr: header{
|
||||
packetType: typeHandshake,
|
||||
flags: flagResponse,
|
||||
connID: hd.connID,
|
||||
sequenceNo: seqNo,
|
||||
timestamp: timestampMicros(),
|
||||
},
|
||||
data: handshakeData{
|
||||
connID: conn.connID,
|
||||
packetSize: uint32(conn.packetSize),
|
||||
}.marshal(),
|
||||
})
|
||||
|
||||
m.incoming <- conn
|
||||
}
|
||||
|
||||
func (m *Mux) incomingHandshakeResponse(from net.Addr, hdr header, data []byte) {
|
||||
m.connsMut.Lock()
|
||||
handShake, ok := m.handshakes[hdr.connID]
|
||||
m.connsMut.Unlock()
|
||||
|
||||
if ok {
|
||||
// This is a response to a handshake in progress.
|
||||
handShake <- packet{
|
||||
dst: nil,
|
||||
hdr: hdr,
|
||||
data: data,
|
||||
}
|
||||
} else if debugMux && hdr.packetType != typeShutdown {
|
||||
log.Printf("Handshake packet %v for unknown conn %v", hdr, hdr.connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) write(pkt packet) (int, error) {
|
||||
buf := m.buffers.Get().([]byte)
|
||||
buf = buf[:dstHeaderLen+len(pkt.data)]
|
||||
pkt.hdr.marshal(buf)
|
||||
copy(buf[dstHeaderLen:], pkt.data)
|
||||
if debugMux {
|
||||
log.Println(m, "write", pkt)
|
||||
}
|
||||
n, err := m.conn.WriteTo(buf, pkt.dst)
|
||||
m.buffers.Put(buf)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (m *Mux) String() string {
|
||||
return fmt.Sprintf("Mux-%v", m.Addr())
|
||||
}
|
||||
|
||||
// Find a unique connection ID
|
||||
func (m *Mux) newConnID() connectionID {
|
||||
for {
|
||||
connID := randomConnID()
|
||||
if _, ok := m.conns[connID]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := m.handshakes[connID]; ok {
|
||||
continue
|
||||
}
|
||||
return connID
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) removeConn(c *Conn) {
|
||||
m.connsMut.Lock()
|
||||
delete(m.conns, c.connID)
|
||||
m.connsMut.Unlock()
|
||||
}
|
||||
119
core/dst/packetlist.go
Normal file
119
core/dst/packetlist.go
Normal file
@ -0,0 +1,119 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
type packetList struct {
|
||||
packets []packet
|
||||
slot int
|
||||
}
|
||||
|
||||
// CutLessSeq cuts packets from the start of the list with sequence numbers
|
||||
// lower than seq. Returns the number of packets that were cut.
|
||||
func (l *packetList) CutLessSeq(seq sequenceNo) int {
|
||||
var i, cut int
|
||||
for i = range l.packets {
|
||||
if i == l.slot {
|
||||
break
|
||||
}
|
||||
if !l.packets[i].LessSeq(seq) {
|
||||
break
|
||||
}
|
||||
cut++
|
||||
}
|
||||
if cut > 0 {
|
||||
l.Cut(cut)
|
||||
}
|
||||
return cut
|
||||
}
|
||||
|
||||
func (l *packetList) Cut(n int) {
|
||||
copy(l.packets, l.packets[n:])
|
||||
l.slot -= n
|
||||
}
|
||||
|
||||
func (l *packetList) Full() bool {
|
||||
return l.slot == len(l.packets)
|
||||
}
|
||||
|
||||
func (l *packetList) All() []packet {
|
||||
return l.packets[:l.slot]
|
||||
}
|
||||
|
||||
func (l *packetList) Append(pkt packet) bool {
|
||||
if l.slot == len(l.packets) {
|
||||
return false
|
||||
}
|
||||
l.packets[l.slot] = pkt
|
||||
l.slot++
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *packetList) AppendAll(pkts []packet) {
|
||||
l.packets = append(l.packets[:l.slot], pkts...)
|
||||
l.slot += len(pkts)
|
||||
}
|
||||
|
||||
func (l *packetList) Cap() int {
|
||||
return len(l.packets)
|
||||
}
|
||||
|
||||
func (l *packetList) Len() int {
|
||||
return l.slot
|
||||
}
|
||||
|
||||
func (l *packetList) Resize(s int) {
|
||||
if s <= cap(l.packets) {
|
||||
l.packets = l.packets[:s]
|
||||
} else {
|
||||
t := make([]packet, s)
|
||||
copy(t, l.packets)
|
||||
l.packets = t
|
||||
}
|
||||
}
|
||||
|
||||
func (l *packetList) InsertSorted(pkt packet) {
|
||||
for i := range l.packets {
|
||||
if i >= l.slot {
|
||||
l.packets[i] = pkt
|
||||
l.slot++
|
||||
return
|
||||
}
|
||||
if pkt.hdr.sequenceNo == l.packets[i].hdr.sequenceNo {
|
||||
return
|
||||
}
|
||||
if pkt.Less(l.packets[i]) {
|
||||
copy(l.packets[i+1:], l.packets[i:])
|
||||
l.packets[i] = pkt
|
||||
if l.slot < len(l.packets) {
|
||||
l.slot++
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *packetList) LowestSeq() sequenceNo {
|
||||
return l.packets[0].hdr.sequenceNo
|
||||
}
|
||||
|
||||
func (l *packetList) PopSequence(maxSeq sequenceNo) []packet {
|
||||
highSeq := l.packets[0].hdr.sequenceNo
|
||||
if highSeq >= maxSeq {
|
||||
return nil
|
||||
}
|
||||
|
||||
var i int
|
||||
for i = 1; i < l.slot; i++ {
|
||||
seq := l.packets[i].hdr.sequenceNo
|
||||
if seq != highSeq+1 || seq >= maxSeq {
|
||||
break
|
||||
}
|
||||
highSeq++
|
||||
}
|
||||
pkts := make([]packet, i)
|
||||
copy(pkts, l.packets[:i])
|
||||
l.Cut(i)
|
||||
return pkts
|
||||
}
|
||||
155
core/dst/packets.go
Normal file
155
core/dst/packets.go
Normal file
@ -0,0 +1,155 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const dstHeaderLen = 12
|
||||
|
||||
type packetType int8
|
||||
|
||||
const (
|
||||
typeHandshake packetType = 0x0
|
||||
typeData = 0x1
|
||||
typeAck = 0x2
|
||||
typeNegAck = 0x3
|
||||
typeShutdown = 0x4
|
||||
)
|
||||
|
||||
func (t packetType) String() string {
|
||||
switch t {
|
||||
case typeData:
|
||||
return "data"
|
||||
case typeHandshake:
|
||||
return "handshake"
|
||||
case typeAck:
|
||||
return "ack"
|
||||
case typeNegAck:
|
||||
return "negAck"
|
||||
case typeShutdown:
|
||||
return "shutdown"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type connectionID uint32
|
||||
|
||||
func (c connectionID) String() string {
|
||||
return fmt.Sprintf("Ci%08x", uint32(c))
|
||||
}
|
||||
|
||||
type sequenceNo uint32
|
||||
|
||||
func (s sequenceNo) String() string {
|
||||
return fmt.Sprintf("Sq%d", uint32(s))
|
||||
}
|
||||
|
||||
type timestamp uint32
|
||||
|
||||
func (t timestamp) String() string {
|
||||
return fmt.Sprintf("Ts%d", uint32(t))
|
||||
}
|
||||
|
||||
const (
|
||||
flagRequest = 1 << 0 // This packet is a handshake request
|
||||
flagResponse = 1 << 1 // This packet is a handshake response
|
||||
flagCookie = 1 << 2 // This packet contains a coookie challenge
|
||||
)
|
||||
|
||||
type header struct {
|
||||
packetType packetType // 4 bits
|
||||
flags uint8 // 4 bits
|
||||
connID connectionID // 24 bits
|
||||
sequenceNo sequenceNo
|
||||
timestamp timestamp
|
||||
}
|
||||
|
||||
func (h header) marshal(bs []byte) {
|
||||
binary.BigEndian.PutUint32(bs, uint32(h.connID&0xffffff))
|
||||
bs[0] = h.flags | uint8(h.packetType)<<4
|
||||
binary.BigEndian.PutUint32(bs[4:], uint32(h.sequenceNo))
|
||||
binary.BigEndian.PutUint32(bs[8:], uint32(h.timestamp))
|
||||
}
|
||||
|
||||
func unmarshalHeader(bs []byte) header {
|
||||
var h header
|
||||
h.packetType = packetType(bs[0] >> 4)
|
||||
h.flags = bs[0] & 0xf
|
||||
h.connID = connectionID(binary.BigEndian.Uint32(bs) & 0xffffff)
|
||||
h.sequenceNo = sequenceNo(binary.BigEndian.Uint32(bs[4:]))
|
||||
h.timestamp = timestamp(binary.BigEndian.Uint32(bs[8:]))
|
||||
return h
|
||||
}
|
||||
|
||||
func (h header) String() string {
|
||||
return fmt.Sprintf("header{type=%s flags=0x%x connID=%v seq=%v time=%v}", h.packetType, h.flags, h.connID, h.sequenceNo, h.timestamp)
|
||||
}
|
||||
|
||||
type handshakeData struct {
|
||||
packetSize uint32
|
||||
connID connectionID
|
||||
cookie uint32
|
||||
}
|
||||
|
||||
func (h handshakeData) marshalInto(data []byte) {
|
||||
binary.BigEndian.PutUint32(data[0:], h.packetSize)
|
||||
binary.BigEndian.PutUint32(data[4:], uint32(h.connID))
|
||||
binary.BigEndian.PutUint32(data[8:], h.cookie)
|
||||
}
|
||||
|
||||
func (h handshakeData) marshal() []byte {
|
||||
var data [12]byte
|
||||
h.marshalInto(data[:])
|
||||
return data[:]
|
||||
}
|
||||
|
||||
func unmarshalHandshakeData(data []byte) handshakeData {
|
||||
var h handshakeData
|
||||
h.packetSize = binary.BigEndian.Uint32(data[0:])
|
||||
h.connID = connectionID(binary.BigEndian.Uint32(data[4:]))
|
||||
h.cookie = binary.BigEndian.Uint32(data[8:])
|
||||
return h
|
||||
}
|
||||
|
||||
func (h handshakeData) String() string {
|
||||
return fmt.Sprintf("handshake{size=%d connID=%v cookie=0x%08x}", h.packetSize, h.connID, h.cookie)
|
||||
}
|
||||
|
||||
type packet struct {
|
||||
src connectionID
|
||||
dst net.Addr
|
||||
hdr header
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (p packet) String() string {
|
||||
var dst string
|
||||
if p.dst != nil {
|
||||
dst = "dst=" + p.dst.String() + " "
|
||||
}
|
||||
switch p.hdr.packetType {
|
||||
case typeHandshake:
|
||||
return fmt.Sprintf("%spacket{src=%v %v %v}", dst, p.src, p.hdr, unmarshalHandshakeData(p.data))
|
||||
default:
|
||||
return fmt.Sprintf("%spacket{src=%v %v data[:%d]}", dst, p.src, p.hdr, len(p.data))
|
||||
}
|
||||
}
|
||||
|
||||
func (p packet) LessSeq(seq sequenceNo) bool {
|
||||
diff := seq - p.hdr.sequenceNo
|
||||
if diff == 0 {
|
||||
return false
|
||||
}
|
||||
return diff < 1<<31
|
||||
}
|
||||
|
||||
func (a packet) Less(b packet) bool {
|
||||
return a.LessSeq(b.hdr.sequenceNo)
|
||||
}
|
||||
260
core/dst/sendbuffer.go
Normal file
260
core/dst/sendbuffer.go
Normal file
@ -0,0 +1,260 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"sync"
|
||||
|
||||
"github.com/juju/ratelimit"
|
||||
)
|
||||
|
||||
/*
|
||||
sendWindow
|
||||
v
|
||||
[S|S|S|S|Q|Q|Q|Q| | | | | | | | | ]
|
||||
^ ^writeSlot
|
||||
sendSlot
|
||||
*/
|
||||
type sendBuffer struct {
|
||||
mux *Mux // we send packets here
|
||||
scheduler *ratelimit.Bucket // sets send rate for packets
|
||||
|
||||
sendWindow int // maximum number of outstanding non-acked packets
|
||||
packetRate int // target pps
|
||||
|
||||
send packetList // buffered packets
|
||||
sendSlot int // buffer slot from which to send next packet
|
||||
|
||||
lost packetList // list of packets reported lost by timeout
|
||||
lostSlot int // next lost packet to resend
|
||||
|
||||
closed bool
|
||||
closing bool
|
||||
mut sync.Mutex
|
||||
cond *sync.Cond
|
||||
}
|
||||
|
||||
const (
|
||||
schedulerRate = 1e6
|
||||
schedulerCapacity = schedulerRate / 40
|
||||
)
|
||||
|
||||
// newSendBuffer creates a new send buffer with a zero window.
|
||||
// SetRateAndWindow() must be called to set an initial packet rate and send
|
||||
// window before using.
|
||||
func newSendBuffer(m *Mux) *sendBuffer {
|
||||
b := &sendBuffer{
|
||||
mux: m,
|
||||
scheduler: ratelimit.NewBucketWithRate(schedulerRate, schedulerCapacity),
|
||||
}
|
||||
b.cond = sync.NewCond(&b.mut)
|
||||
go b.writerLoop()
|
||||
return b
|
||||
}
|
||||
|
||||
// Write puts a new packet in send buffer and schedules a send. Blocks when
|
||||
// the window size is or would be exceeded.
|
||||
func (b *sendBuffer) Write(pkt packet) error {
|
||||
b.mut.Lock()
|
||||
defer b.mut.Unlock()
|
||||
|
||||
for b.send.Full() || b.send.Len() >= b.sendWindow {
|
||||
if b.closing {
|
||||
return ErrClosedConn
|
||||
}
|
||||
if debugConnection {
|
||||
log.Println(b, "Write blocked")
|
||||
}
|
||||
b.cond.Wait()
|
||||
}
|
||||
if !b.send.Append(pkt) {
|
||||
panic("bug: append failed")
|
||||
}
|
||||
b.cond.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Acknowledge removes packets with lower sequence numbers from the loss list
|
||||
// or send buffer.
|
||||
func (b *sendBuffer) Acknowledge(seq sequenceNo) {
|
||||
b.mut.Lock()
|
||||
|
||||
if cut := b.lost.CutLessSeq(seq); cut > 0 {
|
||||
if debugConnection {
|
||||
log.Println(b, "cut", cut, "from loss list")
|
||||
}
|
||||
// Next resend should always start with the first packet, regardless
|
||||
// of what we might already have resent previously.
|
||||
b.lostSlot = 0
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
|
||||
if cut := b.send.CutLessSeq(seq); cut > 0 {
|
||||
if debugConnection {
|
||||
log.Println(b, "cut", cut, "from send list")
|
||||
}
|
||||
b.sendSlot -= cut
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
func (b *sendBuffer) NegativeAck(seq sequenceNo) {
|
||||
b.mut.Lock()
|
||||
|
||||
pkts := b.send.PopSequence(seq)
|
||||
if cut := len(pkts); cut > 0 {
|
||||
b.lost.AppendAll(pkts)
|
||||
if debugConnection {
|
||||
log.Println(b, "cut", cut, "from send list, adding to loss list")
|
||||
log.Println(seq, pkts)
|
||||
}
|
||||
b.sendSlot -= cut
|
||||
b.lostSlot = 0
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
// ScheduleResend arranges for a resend of all currently unacknowledged
|
||||
// packets.
|
||||
func (b *sendBuffer) ScheduleResend() {
|
||||
b.mut.Lock()
|
||||
|
||||
if b.sendSlot > 0 {
|
||||
// There are packets that have been sent but not acked. Move them from
|
||||
// the send buffer to the loss list for retransmission.
|
||||
if debugConnection {
|
||||
log.Println(b, "scheduled resend from send list", b.sendSlot)
|
||||
}
|
||||
|
||||
// Append the packets to the loss list and rewind the send buffer
|
||||
b.lost.AppendAll(b.send.All()[:b.sendSlot])
|
||||
b.send.Cut(b.sendSlot)
|
||||
b.sendSlot = 0
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
|
||||
if b.lostSlot > 0 {
|
||||
// Also resend whatever was already in the loss list
|
||||
if debugConnection {
|
||||
log.Println(b, "scheduled resend from loss list", b.lostSlot)
|
||||
}
|
||||
b.lostSlot = 0
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
// SetWindowAndRate sets the window size (in packets) and packet rate (in
|
||||
// packets per second) to use when sending.
|
||||
func (b *sendBuffer) SetWindowAndRate(sendWindow, packetRate int) {
|
||||
b.mut.Lock()
|
||||
if debugConnection {
|
||||
log.Println(b, "new window & rate", sendWindow, packetRate)
|
||||
}
|
||||
b.packetRate = packetRate
|
||||
b.sendWindow = sendWindow
|
||||
if b.sendWindow > b.send.Cap() {
|
||||
b.send.Resize(b.sendWindow)
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
// Stop stops the send buffer from any doing further sending, but waits for
|
||||
// the current buffers to be drained.
|
||||
func (b *sendBuffer) Stop() {
|
||||
b.mut.Lock()
|
||||
|
||||
if b.closed || b.closing {
|
||||
return
|
||||
}
|
||||
|
||||
b.closing = true
|
||||
for b.lost.Len() > 0 || b.send.Len() > 0 {
|
||||
b.cond.Wait()
|
||||
}
|
||||
|
||||
b.closed = true
|
||||
b.cond.Broadcast()
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
// CrashStop stops the send buffer from any doing further sending, without
|
||||
// waiting for buffers to drain.
|
||||
func (b *sendBuffer) CrashStop() {
|
||||
b.mut.Lock()
|
||||
|
||||
if b.closed || b.closing {
|
||||
return
|
||||
}
|
||||
|
||||
b.closing = true
|
||||
b.closed = true
|
||||
b.cond.Broadcast()
|
||||
b.mut.Unlock()
|
||||
}
|
||||
|
||||
func (b *sendBuffer) String() string {
|
||||
return fmt.Sprintf("sendBuffer@%p", b)
|
||||
}
|
||||
|
||||
func (b *sendBuffer) writerLoop() {
|
||||
if debugConnection {
|
||||
log.Println(b, "writer() starting")
|
||||
defer log.Println(b, "writer() exiting")
|
||||
}
|
||||
|
||||
b.scheduler.Take(schedulerCapacity)
|
||||
for {
|
||||
var pkt packet
|
||||
b.mut.Lock()
|
||||
for b.lostSlot >= b.sendWindow ||
|
||||
(b.sendSlot == b.send.Len() && b.lostSlot == b.lost.Len()) {
|
||||
if b.closed {
|
||||
b.mut.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if debugConnection {
|
||||
log.Println(b, "writer() paused", b.lostSlot, b.sendSlot, b.sendWindow, b.lost.Len())
|
||||
}
|
||||
b.cond.Wait()
|
||||
}
|
||||
|
||||
if b.lostSlot < b.lost.Len() {
|
||||
pkt = b.lost.All()[b.lostSlot]
|
||||
pkt.hdr.timestamp = timestampMicros()
|
||||
b.lostSlot++
|
||||
|
||||
if debugConnection {
|
||||
log.Println(b, "resend", b.lostSlot, b.lost.Len(), b.sendWindow, pkt.hdr.connID, pkt.hdr.sequenceNo)
|
||||
}
|
||||
} else if b.sendSlot < b.send.Len() {
|
||||
pkt = b.send.All()[b.sendSlot]
|
||||
pkt.hdr.timestamp = timestampMicros()
|
||||
b.sendSlot++
|
||||
|
||||
if debugConnection {
|
||||
log.Println(b, "send", b.sendSlot, b.send.Len(), b.sendWindow, pkt.hdr.connID, pkt.hdr.sequenceNo)
|
||||
}
|
||||
}
|
||||
|
||||
b.cond.Broadcast()
|
||||
packetRate := b.packetRate
|
||||
b.mut.Unlock()
|
||||
|
||||
if pkt.dst != nil {
|
||||
b.scheduler.Wait(schedulerRate / int64(packetRate))
|
||||
b.mux.write(pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
29
core/dst/util.go
Normal file
29
core/dst/util.go
Normal file
@ -0,0 +1,29 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
logger "log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
var log = logger.New(os.Stderr, "", logger.LstdFlags)
|
||||
|
||||
func SetLogger(l *logger.Logger) {
|
||||
log = l
|
||||
}
|
||||
func timestampMicros() timestamp {
|
||||
return timestamp(time.Now().UnixNano() / 1000)
|
||||
}
|
||||
|
||||
func randomSeqNo() sequenceNo {
|
||||
return sequenceNo(rand.Uint32())
|
||||
}
|
||||
|
||||
func randomConnID() connectionID {
|
||||
return connectionID(rand.Uint32() & 0xffffff)
|
||||
}
|
||||
144
core/dst/windowcc.go
Normal file
144
core/dst/windowcc.go
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright 2014 The DST Authors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dst
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type windowCC struct {
|
||||
minWindow int
|
||||
maxWindow int
|
||||
currentWindow int
|
||||
minRate int
|
||||
maxRate int
|
||||
currentRate int
|
||||
targetRate int
|
||||
|
||||
curRTT time.Duration
|
||||
minRTT time.Duration
|
||||
|
||||
statsFile io.WriteCloser
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newWindowCC() *windowCC {
|
||||
var statsFile io.WriteCloser
|
||||
|
||||
if debugCC {
|
||||
statsFile, _ = os.Create(fmt.Sprintf("cc-log-%d.csv", time.Now().Unix()))
|
||||
fmt.Fprintf(statsFile, "ms,minWin,maxWin,curWin,minRate,maxRate,curRate,minRTT,curRTT\n")
|
||||
}
|
||||
|
||||
return &windowCC{
|
||||
minWindow: 1, // Packets
|
||||
maxWindow: 16 << 10,
|
||||
currentWindow: 1,
|
||||
|
||||
minRate: 100, // PPS
|
||||
maxRate: 80e3, // Roughly 1 Gbps at 1500 bytes per packet
|
||||
currentRate: 100,
|
||||
targetRate: 1000,
|
||||
|
||||
minRTT: 10 * time.Second,
|
||||
statsFile: statsFile,
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *windowCC) Ack() {
|
||||
if w.curRTT > w.minRTT+100*time.Millisecond {
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
|
||||
if w.currentWindow < w.maxWindow {
|
||||
w.currentWindow++
|
||||
changed = true
|
||||
}
|
||||
|
||||
if w.currentRate != w.targetRate {
|
||||
w.currentRate = (w.currentRate*7 + w.targetRate) / 8
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed && debugCC {
|
||||
w.log()
|
||||
log.Println("Ack", w.currentWindow, w.currentRate)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *windowCC) NegAck() {
|
||||
if w.currentWindow > w.minWindow {
|
||||
w.currentWindow /= 2
|
||||
}
|
||||
if w.currentRate > w.minRate {
|
||||
w.currentRate /= 2
|
||||
}
|
||||
if debugCC {
|
||||
w.log()
|
||||
log.Println("NegAck", w.currentWindow, w.currentRate)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *windowCC) Exp() {
|
||||
w.currentWindow = w.minWindow
|
||||
if debugCC {
|
||||
w.log()
|
||||
log.Println("Exp", w.currentWindow, w.currentRate)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *windowCC) SendWindow() int {
|
||||
if w.currentWindow < w.minWindow {
|
||||
return w.minWindow
|
||||
}
|
||||
if w.currentWindow > w.maxWindow {
|
||||
return w.maxWindow
|
||||
}
|
||||
return w.currentWindow
|
||||
}
|
||||
|
||||
func (w *windowCC) PacketRate() int {
|
||||
if w.currentRate < w.minRate {
|
||||
return w.minRate
|
||||
}
|
||||
if w.currentRate > w.maxRate {
|
||||
return w.maxRate
|
||||
}
|
||||
return w.currentRate
|
||||
}
|
||||
|
||||
func (w *windowCC) UpdateRTT(rtt time.Duration) {
|
||||
w.curRTT = rtt
|
||||
if w.curRTT < w.minRTT {
|
||||
w.minRTT = w.curRTT
|
||||
if debugCC {
|
||||
log.Println("Min RTT", w.minRTT)
|
||||
}
|
||||
}
|
||||
|
||||
if w.curRTT > w.minRTT+200*time.Millisecond && w.targetRate > 2*w.minRate {
|
||||
w.targetRate -= w.minRate
|
||||
} else if w.curRTT < w.minRTT+20*time.Millisecond && w.targetRate < w.maxRate {
|
||||
w.targetRate += w.minRate
|
||||
}
|
||||
|
||||
if debugCC {
|
||||
w.log()
|
||||
log.Println("RTT", w.curRTT, "target rate", w.targetRate, "current rate", w.currentRate, "current window", w.currentWindow)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *windowCC) log() {
|
||||
if w.statsFile == nil {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w.statsFile, "%.02f,%d,%d,%d,%d,%d,%d,%.02f,%.02f\n", time.Since(w.start).Seconds()*1000, w.minWindow, w.maxWindow, w.currentWindow, w.minRate, w.maxRate, w.currentRate, w.minRTT.Seconds()*1000, w.curRTT.Seconds()*1000)
|
||||
}
|
||||
52
core/lib/buf/leakybuf.go
Normal file
52
core/lib/buf/leakybuf.go
Normal file
@ -0,0 +1,52 @@
|
||||
// Provides leaky buffer, based on the example in Effective Go.
|
||||
package buf
|
||||
|
||||
type LeakyBuf struct {
|
||||
bufSize int // size of each buffer
|
||||
freeList chan []byte
|
||||
}
|
||||
|
||||
const LeakyBufSize = 2048 // data.len(2) + hmacsha1(10) + data(4096)
|
||||
const maxNBuf = 2048
|
||||
|
||||
var LeakyBuffer = NewLeakyBuf(maxNBuf, LeakyBufSize)
|
||||
|
||||
func Get() (b []byte) {
|
||||
return LeakyBuffer.Get()
|
||||
}
|
||||
func Put(b []byte) {
|
||||
LeakyBuffer.Put(b)
|
||||
}
|
||||
|
||||
// NewLeakyBuf creates a leaky buffer which can hold at most n buffer, each
|
||||
// with bufSize bytes.
|
||||
func NewLeakyBuf(n, bufSize int) *LeakyBuf {
|
||||
return &LeakyBuf{
|
||||
bufSize: bufSize,
|
||||
freeList: make(chan []byte, n),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a buffer from the leaky buffer or create a new buffer.
|
||||
func (lb *LeakyBuf) Get() (b []byte) {
|
||||
select {
|
||||
case b = <-lb.freeList:
|
||||
default:
|
||||
b = make([]byte, lb.bufSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Put add the buffer into the free buffer pool for reuse. Panic if the buffer
|
||||
// size is not the same with the leaky buffer's. This is intended to expose
|
||||
// error usage of leaky buffer.
|
||||
func (lb *LeakyBuf) Put(b []byte) {
|
||||
if len(b) != lb.bufSize {
|
||||
panic("invalid buffer size that's put into leaky buffer")
|
||||
}
|
||||
select {
|
||||
case lb.freeList <- b:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
68
core/lib/ioutils/utils.go
Normal file
68
core/lib/ioutils/utils.go
Normal file
@ -0,0 +1,68 @@
|
||||
package ioutils
|
||||
|
||||
import (
|
||||
"io"
|
||||
logger "log"
|
||||
|
||||
lbuf "github.com/snail007/goproxy/core/lib/buf"
|
||||
)
|
||||
|
||||
func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) {
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("bind crashed %s", err)
|
||||
}
|
||||
}()
|
||||
e1 := make(chan interface{}, 1)
|
||||
e2 := make(chan interface{}, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("bind crashed %s", err)
|
||||
}
|
||||
}()
|
||||
//_, err := io.Copy(dst, src)
|
||||
err := ioCopy(dst, src)
|
||||
e1 <- err
|
||||
}()
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("bind crashed %s", err)
|
||||
}
|
||||
}()
|
||||
//_, err := io.Copy(src, dst)
|
||||
err := ioCopy(src, dst)
|
||||
e2 <- err
|
||||
}()
|
||||
var err interface{}
|
||||
select {
|
||||
case err = <-e1:
|
||||
//log.Printf("e1")
|
||||
case err = <-e2:
|
||||
//log.Printf("e2")
|
||||
}
|
||||
src.Close()
|
||||
dst.Close()
|
||||
if fn != nil {
|
||||
fn(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) {
|
||||
buf := lbuf.LeakyBuffer.Get()
|
||||
defer lbuf.LeakyBuffer.Put(buf)
|
||||
n := 0
|
||||
for {
|
||||
n, err = src.Read(buf)
|
||||
if n > 0 {
|
||||
if _, e := dst.Write(buf[0:n]); e != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
24
core/lib/kcpcfg/args.go
Normal file
24
core/lib/kcpcfg/args.go
Normal file
@ -0,0 +1,24 @@
|
||||
package kcpcfg
|
||||
|
||||
import kcp "github.com/xtaci/kcp-go"
|
||||
|
||||
type KCPConfigArgs struct {
|
||||
Key *string
|
||||
Crypt *string
|
||||
Mode *string
|
||||
MTU *int
|
||||
SndWnd *int
|
||||
RcvWnd *int
|
||||
DataShard *int
|
||||
ParityShard *int
|
||||
DSCP *int
|
||||
NoComp *bool
|
||||
AckNodelay *bool
|
||||
NoDelay *int
|
||||
Interval *int
|
||||
Resend *int
|
||||
NoCongestion *int
|
||||
SockBuf *int
|
||||
KeepAlive *int
|
||||
Block kcp.BlockCrypt
|
||||
}
|
||||
315
core/lib/mapx/map.go
Normal file
315
core/lib/mapx/map.go
Normal file
@ -0,0 +1,315 @@
|
||||
package mapx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var SHARD_COUNT = 32
|
||||
|
||||
// A "thread" safe map of type string:Anything.
|
||||
// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
|
||||
type ConcurrentMap []*ConcurrentMapShared
|
||||
|
||||
// A "thread" safe string to anything map.
|
||||
type ConcurrentMapShared struct {
|
||||
items map[string]interface{}
|
||||
sync.RWMutex // Read Write mutex, guards access to internal map.
|
||||
}
|
||||
|
||||
// Creates a new concurrent map.
|
||||
func NewConcurrentMap() ConcurrentMap {
|
||||
m := make(ConcurrentMap, SHARD_COUNT)
|
||||
for i := 0; i < SHARD_COUNT; i++ {
|
||||
m[i] = &ConcurrentMapShared{items: make(map[string]interface{})}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Returns shard under given key
|
||||
func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared {
|
||||
return m[uint(fnv32(key))%uint(SHARD_COUNT)]
|
||||
}
|
||||
|
||||
func (m ConcurrentMap) MSet(data map[string]interface{}) {
|
||||
for key, value := range data {
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
shard.items[key] = value
|
||||
shard.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the given value under the specified key.
|
||||
func (m ConcurrentMap) Set(key string, value interface{}) {
|
||||
// Get map shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
shard.items[key] = value
|
||||
shard.Unlock()
|
||||
}
|
||||
|
||||
// Callback to return new element to be inserted into the map
|
||||
// It is called while lock is held, therefore it MUST NOT
|
||||
// try to access other keys in same map, as it can lead to deadlock since
|
||||
// Go sync.RWLock is not reentrant
|
||||
type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{}
|
||||
|
||||
// Insert or Update - updates existing element or inserts a new one using UpsertCb
|
||||
func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) {
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
v, ok := shard.items[key]
|
||||
res = cb(ok, v, value)
|
||||
shard.items[key] = res
|
||||
shard.Unlock()
|
||||
return res
|
||||
}
|
||||
|
||||
// Sets the given value under the specified key if no value was associated with it.
|
||||
func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
|
||||
// Get map shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
_, ok := shard.items[key]
|
||||
if !ok {
|
||||
shard.items[key] = value
|
||||
}
|
||||
shard.Unlock()
|
||||
return !ok
|
||||
}
|
||||
|
||||
// Retrieves an element from map under given key.
|
||||
func (m ConcurrentMap) Get(key string) (interface{}, bool) {
|
||||
// Get shard
|
||||
shard := m.GetShard(key)
|
||||
shard.RLock()
|
||||
// Get item from shard.
|
||||
val, ok := shard.items[key]
|
||||
shard.RUnlock()
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Returns the number of elements within the map.
|
||||
func (m ConcurrentMap) Count() int {
|
||||
count := 0
|
||||
for i := 0; i < SHARD_COUNT; i++ {
|
||||
shard := m[i]
|
||||
shard.RLock()
|
||||
count += len(shard.items)
|
||||
shard.RUnlock()
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Looks up an item under specified key
|
||||
func (m ConcurrentMap) Has(key string) bool {
|
||||
// Get shard
|
||||
shard := m.GetShard(key)
|
||||
shard.RLock()
|
||||
// See if element is within shard.
|
||||
_, ok := shard.items[key]
|
||||
shard.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// Removes an element from the map.
|
||||
func (m ConcurrentMap) Remove(key string) {
|
||||
// Try to get shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
delete(shard.items, key)
|
||||
shard.Unlock()
|
||||
}
|
||||
|
||||
// Removes an element from the map and returns it
|
||||
func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
|
||||
// Try to get shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
v, exists = shard.items[key]
|
||||
delete(shard.items, key)
|
||||
shard.Unlock()
|
||||
return v, exists
|
||||
}
|
||||
|
||||
// Checks if map is empty.
|
||||
func (m ConcurrentMap) IsEmpty() bool {
|
||||
return m.Count() == 0
|
||||
}
|
||||
|
||||
// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
|
||||
type Tuple struct {
|
||||
Key string
|
||||
Val interface{}
|
||||
}
|
||||
|
||||
// Returns an iterator which could be used in a for range loop.
|
||||
//
|
||||
// Deprecated: using IterBuffered() will get a better performence
|
||||
func (m ConcurrentMap) Iter() <-chan Tuple {
|
||||
chans := snapshot(m)
|
||||
ch := make(chan Tuple)
|
||||
go fanIn(chans, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Returns a buffered iterator which could be used in a for range loop.
|
||||
func (m ConcurrentMap) IterBuffered() <-chan Tuple {
|
||||
chans := snapshot(m)
|
||||
total := 0
|
||||
for _, c := range chans {
|
||||
total += cap(c)
|
||||
}
|
||||
ch := make(chan Tuple, total)
|
||||
go fanIn(chans, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Returns a array of channels that contains elements in each shard,
|
||||
// which likely takes a snapshot of `m`.
|
||||
// It returns once the size of each buffered channel is determined,
|
||||
// before all the channels are populated using goroutines.
|
||||
func snapshot(m ConcurrentMap) (chans []chan Tuple) {
|
||||
chans = make([]chan Tuple, SHARD_COUNT)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(SHARD_COUNT)
|
||||
// Foreach shard.
|
||||
for index, shard := range m {
|
||||
go func(index int, shard *ConcurrentMapShared) {
|
||||
// Foreach key, value pair.
|
||||
shard.RLock()
|
||||
chans[index] = make(chan Tuple, len(shard.items))
|
||||
wg.Done()
|
||||
for key, val := range shard.items {
|
||||
chans[index] <- Tuple{key, val}
|
||||
}
|
||||
shard.RUnlock()
|
||||
close(chans[index])
|
||||
}(index, shard)
|
||||
}
|
||||
wg.Wait()
|
||||
return chans
|
||||
}
|
||||
|
||||
// fanIn reads elements from channels `chans` into channel `out`
|
||||
func fanIn(chans []chan Tuple, out chan Tuple) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(chans))
|
||||
for _, ch := range chans {
|
||||
go func(ch chan Tuple) {
|
||||
for t := range ch {
|
||||
out <- t
|
||||
}
|
||||
wg.Done()
|
||||
}(ch)
|
||||
}
|
||||
wg.Wait()
|
||||
close(out)
|
||||
}
|
||||
|
||||
// Returns all items as map[string]interface{}
|
||||
func (m ConcurrentMap) Items() map[string]interface{} {
|
||||
tmp := make(map[string]interface{})
|
||||
|
||||
// Insert items to temporary map.
|
||||
for item := range m.IterBuffered() {
|
||||
tmp[item.Key] = item.Val
|
||||
}
|
||||
|
||||
return tmp
|
||||
}
|
||||
|
||||
// Iterator callback,called for every key,value found in
|
||||
// maps. RLock is held for all calls for a given shard
|
||||
// therefore callback sess consistent view of a shard,
|
||||
// but not across the shards
|
||||
type IterCb func(key string, v interface{})
|
||||
|
||||
// Callback based iterator, cheapest way to read
|
||||
// all elements in a map.
|
||||
func (m ConcurrentMap) IterCb(fn IterCb) {
|
||||
for idx := range m {
|
||||
shard := (m)[idx]
|
||||
shard.RLock()
|
||||
for key, value := range shard.items {
|
||||
fn(key, value)
|
||||
}
|
||||
shard.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Return all keys as []string
|
||||
func (m ConcurrentMap) Keys() []string {
|
||||
count := m.Count()
|
||||
ch := make(chan string, count)
|
||||
go func() {
|
||||
// Foreach shard.
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(SHARD_COUNT)
|
||||
for _, shard := range m {
|
||||
go func(shard *ConcurrentMapShared) {
|
||||
// Foreach key, value pair.
|
||||
shard.RLock()
|
||||
for key := range shard.items {
|
||||
ch <- key
|
||||
}
|
||||
shard.RUnlock()
|
||||
wg.Done()
|
||||
}(shard)
|
||||
}
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
// Generate keys
|
||||
keys := make([]string, 0, count)
|
||||
for k := range ch {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
//Reviles ConcurrentMap "private" variables to json marshal.
|
||||
func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
|
||||
// Create a temporary map, which will hold all item spread across shards.
|
||||
tmp := make(map[string]interface{})
|
||||
|
||||
// Insert items to temporary map.
|
||||
for item := range m.IterBuffered() {
|
||||
tmp[item.Key] = item.Val
|
||||
}
|
||||
return json.Marshal(tmp)
|
||||
}
|
||||
|
||||
func fnv32(key string) uint32 {
|
||||
hash := uint32(2166136261)
|
||||
const prime32 = uint32(16777619)
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash *= prime32
|
||||
hash ^= uint32(key[i])
|
||||
}
|
||||
return hash
|
||||
}
|
||||
|
||||
// Concurrent map uses Interface{} as its value, therefor JSON Unmarshal
|
||||
// will probably won't know which to type to unmarshal into, in such case
|
||||
// we'll end up with a value of type map[string]interface{}, In most cases this isn't
|
||||
// out value type, this is why we've decided to remove this functionality.
|
||||
|
||||
// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) {
|
||||
// // Reverse process of Marshal.
|
||||
|
||||
// tmp := make(map[string]interface{})
|
||||
|
||||
// // Unmarshal into a single map.
|
||||
// if err := json.Unmarshal(b, &tmp); err != nil {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // foreach key,value pair in temporary map insert into our concurrent map.
|
||||
// for key, val := range tmp {
|
||||
// m.Set(key, val)
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
159
core/lib/socks5/socks5.go
Normal file
159
core/lib/socks5/socks5.go
Normal file
@ -0,0 +1,159 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
Method_NO_AUTH = uint8(0x00)
|
||||
Method_GSSAPI = uint8(0x01)
|
||||
Method_USER_PASS = uint8(0x02)
|
||||
Method_IANA = uint8(0x7F)
|
||||
Method_RESVERVE = uint8(0x80)
|
||||
Method_NONE_ACCEPTABLE = uint8(0xFF)
|
||||
VERSION_V5 = uint8(0x05)
|
||||
CMD_CONNECT = uint8(0x01)
|
||||
CMD_BIND = uint8(0x02)
|
||||
CMD_ASSOCIATE = uint8(0x03)
|
||||
ATYP_IPV4 = uint8(0x01)
|
||||
ATYP_DOMAIN = uint8(0x03)
|
||||
ATYP_IPV6 = uint8(0x04)
|
||||
REP_SUCCESS = uint8(0x00)
|
||||
REP_REQ_FAIL = uint8(0x01)
|
||||
REP_RULE_FORBIDDEN = uint8(0x02)
|
||||
REP_NETWOR_UNREACHABLE = uint8(0x03)
|
||||
REP_HOST_UNREACHABLE = uint8(0x04)
|
||||
REP_CONNECTION_REFUSED = uint8(0x05)
|
||||
REP_TTL_TIMEOUT = uint8(0x06)
|
||||
REP_CMD_UNSUPPORTED = uint8(0x07)
|
||||
REP_ATYP_UNSUPPORTED = uint8(0x08)
|
||||
REP_UNKNOWN = uint8(0x09)
|
||||
RSV = uint8(0x00)
|
||||
)
|
||||
|
||||
var (
|
||||
ZERO_IP = []byte{0x00, 0x00, 0x00, 0x00}
|
||||
ZERO_PORT = []byte{0x00, 0x00}
|
||||
)
|
||||
var Socks5Errors = []string{
|
||||
"",
|
||||
"general failure",
|
||||
"connection forbidden",
|
||||
"network unreachable",
|
||||
"host unreachable",
|
||||
"connection refused",
|
||||
"TTL expired",
|
||||
"command not supported",
|
||||
"address type not supported",
|
||||
}
|
||||
|
||||
// Auth contains authentication parameters that specific Dialers may require.
|
||||
type UsernamePassword struct {
|
||||
Username, Password string
|
||||
}
|
||||
|
||||
type PacketUDP struct {
|
||||
rsv uint16
|
||||
frag uint8
|
||||
atype uint8
|
||||
dstHost string
|
||||
dstPort string
|
||||
data []byte
|
||||
}
|
||||
|
||||
func NewPacketUDP() (p PacketUDP) {
|
||||
return PacketUDP{}
|
||||
}
|
||||
func (p *PacketUDP) Build(destAddr string, data []byte) (err error) {
|
||||
host, port, err := net.SplitHostPort(destAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.rsv = 0
|
||||
p.frag = 0
|
||||
p.dstHost = host
|
||||
p.dstPort = port
|
||||
p.atype = ATYP_IPV4
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
p.atype = ATYP_IPV4
|
||||
ip = ip4
|
||||
} else {
|
||||
p.atype = ATYP_IPV6
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
err = errors.New("proxy: destination host name too long: " + host)
|
||||
return
|
||||
}
|
||||
p.atype = ATYP_DOMAIN
|
||||
}
|
||||
p.data = data
|
||||
|
||||
return
|
||||
}
|
||||
func (p *PacketUDP) Parse(b []byte) (err error) {
|
||||
p.frag = uint8(b[2])
|
||||
if p.frag != 0 {
|
||||
err = fmt.Errorf("FRAG only support for 0 , %v ,%v", p.frag, b[:4])
|
||||
return
|
||||
}
|
||||
portIndex := 0
|
||||
p.atype = b[3]
|
||||
switch p.atype {
|
||||
case ATYP_IPV4: //IP V4
|
||||
p.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String()
|
||||
portIndex = 8
|
||||
case ATYP_DOMAIN: //域名
|
||||
domainLen := uint8(b[4])
|
||||
p.dstHost = string(b[5 : 5+domainLen]) //b[4]表示域名的长度
|
||||
portIndex = int(5 + domainLen)
|
||||
case ATYP_IPV6: //IP V6
|
||||
p.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String()
|
||||
portIndex = 20
|
||||
}
|
||||
p.dstPort = strconv.Itoa(int(b[portIndex])<<8 | int(b[portIndex+1]))
|
||||
p.data = b[portIndex+2:]
|
||||
return
|
||||
}
|
||||
func (p *PacketUDP) Header() []byte {
|
||||
header := new(bytes.Buffer)
|
||||
header.Write([]byte{0x00, 0x00, p.frag, p.atype})
|
||||
if p.atype == ATYP_IPV4 {
|
||||
ip := net.ParseIP(p.dstHost)
|
||||
header.Write(ip.To4())
|
||||
} else if p.atype == ATYP_IPV6 {
|
||||
ip := net.ParseIP(p.dstHost)
|
||||
header.Write(ip.To16())
|
||||
} else if p.atype == ATYP_DOMAIN {
|
||||
hBytes := []byte(p.dstHost)
|
||||
header.WriteByte(byte(len(hBytes)))
|
||||
header.Write(hBytes)
|
||||
}
|
||||
port, _ := strconv.ParseUint(p.dstPort, 10, 64)
|
||||
portBytes := new(bytes.Buffer)
|
||||
binary.Write(portBytes, binary.BigEndian, port)
|
||||
header.Write(portBytes.Bytes()[portBytes.Len()-2:])
|
||||
return header.Bytes()
|
||||
}
|
||||
func (p *PacketUDP) Bytes() []byte {
|
||||
packBytes := new(bytes.Buffer)
|
||||
packBytes.Write(p.Header())
|
||||
packBytes.Write(p.data)
|
||||
return packBytes.Bytes()
|
||||
}
|
||||
func (p *PacketUDP) Host() string {
|
||||
return p.dstHost
|
||||
}
|
||||
|
||||
func (p *PacketUDP) Port() string {
|
||||
return p.dstPort
|
||||
}
|
||||
func (p *PacketUDP) Data() []byte {
|
||||
return p.data
|
||||
}
|
||||
59
core/lib/transport/compress.go
Normal file
59
core/lib/transport/compress.go
Normal file
@ -0,0 +1,59 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/golang/snappy"
|
||||
)
|
||||
|
||||
func NewCompStream(conn net.Conn) *CompStream {
|
||||
c := new(CompStream)
|
||||
c.conn = conn
|
||||
c.w = snappy.NewBufferedWriter(conn)
|
||||
c.r = snappy.NewReader(conn)
|
||||
return c
|
||||
}
|
||||
func NewCompConn(conn net.Conn) net.Conn {
|
||||
c := CompStream{}
|
||||
c.conn = conn
|
||||
c.w = snappy.NewBufferedWriter(conn)
|
||||
c.r = snappy.NewReader(conn)
|
||||
return &c
|
||||
}
|
||||
|
||||
type CompStream struct {
|
||||
net.Conn
|
||||
conn net.Conn
|
||||
w *snappy.Writer
|
||||
r *snappy.Reader
|
||||
}
|
||||
|
||||
func (c *CompStream) Read(p []byte) (n int, err error) {
|
||||
return c.r.Read(p)
|
||||
}
|
||||
|
||||
func (c *CompStream) Write(p []byte) (n int, err error) {
|
||||
n, err = c.w.Write(p)
|
||||
err = c.w.Flush()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *CompStream) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
func (c *CompStream) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
func (c *CompStream) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
func (c *CompStream) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
func (c *CompStream) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
func (c *CompStream) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
40
core/lib/transport/encrypt/conn.go
Normal file
40
core/lib/transport/encrypt/conn.go
Normal file
@ -0,0 +1,40 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
lbuf "github.com/snail007/goproxy/core/lib/buf"
|
||||
)
|
||||
|
||||
var (
|
||||
lBuf = lbuf.NewLeakyBuf(2048, 2048)
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
*Cipher
|
||||
w io.Writer
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func NewConn(c net.Conn, method, password string) (conn net.Conn, err error) {
|
||||
cipher0, err := NewCipher(method, password)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn = &Conn{
|
||||
Conn: c,
|
||||
Cipher: cipher0,
|
||||
r: &cipher.StreamReader{S: cipher0.ReadStream, R: c},
|
||||
w: &cipher.StreamWriter{S: cipher0.WriteStream, W: c},
|
||||
}
|
||||
return
|
||||
}
|
||||
func (s *Conn) Read(b []byte) (n int, err error) {
|
||||
return s.r.Read(b)
|
||||
}
|
||||
func (s *Conn) Write(b []byte) (n int, err error) {
|
||||
return s.w.Write(b)
|
||||
}
|
||||
185
core/lib/transport/encrypt/encrypt.go
Normal file
185
core/lib/transport/encrypt/encrypt.go
Normal file
@ -0,0 +1,185 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/des"
|
||||
"crypto/md5"
|
||||
"crypto/rc4"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
|
||||
lbuf "github.com/snail007/goproxy/core/lib/buf"
|
||||
"github.com/Yawning/chacha20"
|
||||
"golang.org/x/crypto/blowfish"
|
||||
"golang.org/x/crypto/cast5"
|
||||
)
|
||||
|
||||
const leakyBufSize = 2048
|
||||
const maxNBuf = 2048
|
||||
|
||||
var leakyBuf = lbuf.NewLeakyBuf(maxNBuf, leakyBufSize)
|
||||
var errEmptyPassword = errors.New("proxy key")
|
||||
|
||||
func md5sum(d []byte) []byte {
|
||||
h := md5.New()
|
||||
h.Write(d)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func evpBytesToKey(password string, keyLen int) (key []byte) {
|
||||
const md5Len = 16
|
||||
cnt := (keyLen-1)/md5Len + 1
|
||||
m := make([]byte, cnt*md5Len)
|
||||
copy(m, md5sum([]byte(password)))
|
||||
|
||||
// Repeatedly call md5 until bytes generated is enough.
|
||||
// Each call to md5 uses data: prev md5 sum + password.
|
||||
d := make([]byte, md5Len+len(password))
|
||||
start := 0
|
||||
for i := 1; i < cnt; i++ {
|
||||
start += md5Len
|
||||
copy(d, m[start-md5Len:start])
|
||||
copy(d[md5Len:], password)
|
||||
copy(m[start:], md5sum(d))
|
||||
}
|
||||
return m[:keyLen]
|
||||
}
|
||||
|
||||
type DecOrEnc int
|
||||
|
||||
const (
|
||||
Decrypt DecOrEnc = iota
|
||||
Encrypt
|
||||
)
|
||||
|
||||
func newStream(block cipher.Block, err error, key, iv []byte,
|
||||
doe DecOrEnc) (cipher.Stream, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if doe == Encrypt {
|
||||
return cipher.NewCFBEncrypter(block, iv), nil
|
||||
} else {
|
||||
return cipher.NewCFBDecrypter(block, iv), nil
|
||||
}
|
||||
}
|
||||
|
||||
func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cipher.NewCTR(block, iv), nil
|
||||
}
|
||||
|
||||
func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := des.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newBlowFishStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := blowfish.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newCast5Stream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := cast5.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
|
||||
h := md5.New()
|
||||
h.Write(key)
|
||||
h.Write(iv)
|
||||
rc4key := h.Sum(nil)
|
||||
|
||||
return rc4.NewCipher(rc4key)
|
||||
}
|
||||
|
||||
func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
|
||||
return chacha20.NewCipher(key, iv)
|
||||
}
|
||||
|
||||
func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
|
||||
return chacha20.NewCipher(key, iv)
|
||||
}
|
||||
|
||||
type cipherInfo struct {
|
||||
keyLen int
|
||||
ivLen int
|
||||
newStream func(key, iv []byte, doe DecOrEnc) (cipher.Stream, error)
|
||||
}
|
||||
|
||||
var cipherMethod = map[string]*cipherInfo{
|
||||
"aes-128-cfb": {16, 16, newAESCFBStream},
|
||||
"aes-192-cfb": {24, 16, newAESCFBStream},
|
||||
"aes-256-cfb": {32, 16, newAESCFBStream},
|
||||
"aes-128-ctr": {16, 16, newAESCTRStream},
|
||||
"aes-192-ctr": {24, 16, newAESCTRStream},
|
||||
"aes-256-ctr": {32, 16, newAESCTRStream},
|
||||
"des-cfb": {8, 8, newDESStream},
|
||||
"bf-cfb": {16, 8, newBlowFishStream},
|
||||
"cast5-cfb": {16, 8, newCast5Stream},
|
||||
"rc4-md5": {16, 16, newRC4MD5Stream},
|
||||
"rc4-md5-6": {16, 6, newRC4MD5Stream},
|
||||
"chacha20": {32, 8, newChaCha20Stream},
|
||||
"chacha20-ietf": {32, 12, newChaCha20IETFStream},
|
||||
}
|
||||
|
||||
func GetCipherMethods() (keys []string) {
|
||||
keys = []string{}
|
||||
for k := range cipherMethod {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return
|
||||
}
|
||||
func CheckCipherMethod(method string) error {
|
||||
if method == "" {
|
||||
method = "aes-256-cfb"
|
||||
}
|
||||
_, ok := cipherMethod[method]
|
||||
if !ok {
|
||||
return errors.New("Unsupported encryption method: " + method)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Cipher struct {
|
||||
WriteStream cipher.Stream
|
||||
ReadStream cipher.Stream
|
||||
key []byte
|
||||
info *cipherInfo
|
||||
}
|
||||
|
||||
func NewCipher(method, password string) (c *Cipher, err error) {
|
||||
if password == "" {
|
||||
return nil, errEmptyPassword
|
||||
}
|
||||
mi, ok := cipherMethod[method]
|
||||
if !ok {
|
||||
return nil, errors.New("Unsupported encryption method: " + method)
|
||||
}
|
||||
key := evpBytesToKey(password, mi.keyLen)
|
||||
c = &Cipher{key: key, info: mi}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//hash(key) -> read IV
|
||||
riv := sha256.New().Sum(c.key)[:c.info.ivLen]
|
||||
c.ReadStream, err = c.info.newStream(c.key, riv, Decrypt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} //hash(read IV) -> write IV
|
||||
wiv := sha256.New().Sum(riv)[:c.info.ivLen]
|
||||
c.WriteStream, err = c.info.newStream(c.key, wiv, Encrypt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
212
core/lib/udp/udp.go
Normal file
212
core/lib/udp/udp.go
Normal file
@ -0,0 +1,212 @@
|
||||
package udputils
|
||||
|
||||
import (
|
||||
logger "log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
bufx "github.com/snail007/goproxy/core/lib/buf"
|
||||
mapx "github.com/snail007/goproxy/core/lib/mapx"
|
||||
)
|
||||
|
||||
type CreateOutUDPConnFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, packet []byte) (outconn *net.UDPConn, err error)
|
||||
type CleanFn func(srcAddr string)
|
||||
type BeforeSendFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, b []byte) (sendB []byte, err error)
|
||||
type BeforeReplyFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, outconn *net.UDPConn, b []byte) (replyB []byte, err error)
|
||||
|
||||
type IOBinder struct {
|
||||
outConns mapx.ConcurrentMap
|
||||
listener *net.UDPConn
|
||||
createOutUDPConnFn CreateOutUDPConnFn
|
||||
log *logger.Logger
|
||||
timeout time.Duration
|
||||
cleanFn CleanFn
|
||||
inTCPConn *net.Conn
|
||||
outTCPConn *net.Conn
|
||||
beforeSendFn BeforeSendFn
|
||||
beforeReplyFn BeforeReplyFn
|
||||
}
|
||||
|
||||
func NewIOBinder(listener *net.UDPConn, log *logger.Logger) *IOBinder {
|
||||
return &IOBinder{
|
||||
listener: listener,
|
||||
outConns: mapx.NewConcurrentMap(),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
func (s *IOBinder) Factory(fn CreateOutUDPConnFn) *IOBinder {
|
||||
s.createOutUDPConnFn = fn
|
||||
return s
|
||||
}
|
||||
func (s *IOBinder) AfterReadFromClient(fn BeforeSendFn) *IOBinder {
|
||||
s.beforeSendFn = fn
|
||||
return s
|
||||
}
|
||||
func (s *IOBinder) AfterReadFromServer(fn BeforeReplyFn) *IOBinder {
|
||||
s.beforeReplyFn = fn
|
||||
return s
|
||||
}
|
||||
func (s *IOBinder) Timeout(timeout time.Duration) *IOBinder {
|
||||
s.timeout = timeout
|
||||
return s
|
||||
}
|
||||
func (s *IOBinder) Clean(fn CleanFn) *IOBinder {
|
||||
s.cleanFn = fn
|
||||
return s
|
||||
}
|
||||
func (s *IOBinder) AliveWithServeConn(srcAddr string, inTCPConn *net.Conn) *IOBinder {
|
||||
s.inTCPConn = inTCPConn
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
(*inTCPConn).SetReadDeadline(time.Time{})
|
||||
if _, err := (*inTCPConn).Read(buf); err != nil {
|
||||
s.log.Printf("udp related tcp conn of client disconnected with read , %s", err.Error())
|
||||
s.clean(srcAddr)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
(*inTCPConn).SetWriteDeadline(time.Now().Add(time.Second * 5))
|
||||
if _, err := (*inTCPConn).Write([]byte{0x00}); err != nil {
|
||||
s.log.Printf("udp related tcp conn of client disconnected with write , %s", err.Error())
|
||||
s.clean(srcAddr)
|
||||
return
|
||||
}
|
||||
(*inTCPConn).SetWriteDeadline(time.Time{})
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}()
|
||||
return s
|
||||
}
|
||||
func (s *IOBinder) AliveWithClientConn(srcAddr string, outTCPConn *net.Conn) *IOBinder {
|
||||
s.outTCPConn = outTCPConn
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
(*outTCPConn).SetReadDeadline(time.Time{})
|
||||
if _, err := (*outTCPConn).Read(buf); err != nil {
|
||||
s.log.Printf("udp related tcp conn to parent disconnected with read , %s", err.Error())
|
||||
s.clean(srcAddr)
|
||||
}
|
||||
}()
|
||||
return s
|
||||
}
|
||||
func (s *IOBinder) Run() (err error) {
|
||||
var (
|
||||
isClosedErr = func(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "use of closed network connection")
|
||||
}
|
||||
isTimeoutErr = func(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
e, ok := err.(net.Error)
|
||||
return ok && e.Timeout()
|
||||
}
|
||||
isRefusedErr = func(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "connection refused")
|
||||
}
|
||||
)
|
||||
for {
|
||||
buf := bufx.Get()
|
||||
defer bufx.Put(buf)
|
||||
n, srcAddr, err := s.listener.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
s.log.Printf("read from client error %s", err)
|
||||
if isClosedErr(err) {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
var data []byte
|
||||
if s.beforeSendFn != nil {
|
||||
data, err = s.beforeSendFn(s.listener, srcAddr, buf[:n])
|
||||
if err != nil {
|
||||
s.log.Printf("beforeSend retured an error , %s", err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
data = buf[:n]
|
||||
}
|
||||
inconnRemoteAddr := srcAddr.String()
|
||||
var outconn *net.UDPConn
|
||||
if v, ok := s.outConns.Get(inconnRemoteAddr); !ok {
|
||||
outconn, err = s.createOutUDPConnFn(s.listener, srcAddr, data)
|
||||
if err != nil {
|
||||
s.log.Printf("connnect fail %s", err)
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
s.clean(srcAddr.String())
|
||||
}()
|
||||
buf := bufx.Get()
|
||||
defer bufx.Put(buf)
|
||||
for {
|
||||
if s.timeout > 0 {
|
||||
outconn.SetReadDeadline(time.Now().Add(s.timeout))
|
||||
}
|
||||
n, srcAddr, err := outconn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
s.log.Printf("read from remote error %s", err)
|
||||
if isClosedErr(err) || isTimeoutErr(err) || isRefusedErr(err) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
data := buf[:n]
|
||||
if s.beforeReplyFn != nil {
|
||||
data, err = s.beforeReplyFn(s.listener, srcAddr, outconn, buf[:n])
|
||||
if err != nil {
|
||||
s.log.Printf("beforeReply retured an error , %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
_, err = s.listener.WriteTo(data, srcAddr)
|
||||
if err != nil {
|
||||
s.log.Printf("write to remote error %s", err)
|
||||
if isClosedErr(err) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
outconn = v.(*net.UDPConn)
|
||||
}
|
||||
|
||||
s.log.Printf("use decrpyted data , %v", data)
|
||||
|
||||
_, err = outconn.Write(data)
|
||||
|
||||
if err != nil {
|
||||
s.log.Printf("write to remote error %s", err)
|
||||
if isClosedErr(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
func (s *IOBinder) clean(srcAddr string) *IOBinder {
|
||||
if v, ok := s.outConns.Get(srcAddr); ok {
|
||||
(*v.(*net.UDPConn)).Close()
|
||||
s.outConns.Remove(srcAddr)
|
||||
}
|
||||
if s.inTCPConn != nil {
|
||||
(*s.inTCPConn).Close()
|
||||
}
|
||||
if s.outTCPConn != nil {
|
||||
(*s.outTCPConn).Close()
|
||||
}
|
||||
if s.cleanFn != nil {
|
||||
s.cleanFn(srcAddr)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *IOBinder) Close() {
|
||||
for _, c := range s.outConns.Items() {
|
||||
(*c.(*net.UDPConn)).Close()
|
||||
}
|
||||
}
|
||||
31
core/proxy/client/proxy.go
Normal file
31
core/proxy/client/proxy.go
Normal file
@ -0,0 +1,31 @@
|
||||
// Package proxy provides support for a variety of protocols to proxy network
|
||||
// data.
|
||||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
socks5c "github.com/snail007/goproxy/core/lib/socks5"
|
||||
socks5 "github.com/snail007/goproxy/core/proxy/client/socks5"
|
||||
)
|
||||
|
||||
// A Dialer is a means to establish a connection.
|
||||
type Dialer interface {
|
||||
// Dial connects to the given address via the proxy.
|
||||
DialConn(conn *net.Conn, network, addr string) (err error)
|
||||
}
|
||||
|
||||
// Auth contains authentication parameters that specific Dialers may require.
|
||||
type Auth struct {
|
||||
User, Password string
|
||||
}
|
||||
|
||||
func SOCKS5(timeout time.Duration, auth *Auth) (Dialer, error) {
|
||||
var a *socks5c.UsernamePassword
|
||||
if auth != nil {
|
||||
a = &socks5c.UsernamePassword{auth.User, auth.Password}
|
||||
}
|
||||
d := socks5.NewDialer(a, timeout)
|
||||
return d, nil
|
||||
}
|
||||
263
core/proxy/client/socks5/socks5.go
Normal file
263
core/proxy/client/socks5/socks5.go
Normal file
@ -0,0 +1,263 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
socks5c "github.com/snail007/goproxy/core/lib/socks5"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
timeout time.Duration
|
||||
usernamePassword *socks5c.UsernamePassword
|
||||
}
|
||||
|
||||
// NewDialer returns a new Dialer that dials through the provided
|
||||
// proxy server's network and address.
|
||||
func NewDialer(auth *socks5c.UsernamePassword, timeout time.Duration) *Dialer {
|
||||
if auth != nil && auth.Password == "" && auth.Username == "" {
|
||||
auth = nil
|
||||
}
|
||||
return &Dialer{
|
||||
usernamePassword: auth,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) DialConn(conn *net.Conn, network, addr string) (err error) {
|
||||
client := NewClientConn(conn, network, addr, d.timeout, d.usernamePassword, nil)
|
||||
err = client._Handshake()
|
||||
return
|
||||
}
|
||||
|
||||
type ClientConn struct {
|
||||
user string
|
||||
password string
|
||||
conn *net.Conn
|
||||
header []byte
|
||||
timeout time.Duration
|
||||
addr string
|
||||
network string
|
||||
udpAddr string
|
||||
}
|
||||
|
||||
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
|
||||
// with an optional username and password. See RFC 1928 and RFC 1929.
|
||||
// target must be a canonical address with a host and port.
|
||||
// network : tcp udp
|
||||
func NewClientConn(conn *net.Conn, network, target string, timeout time.Duration, auth *socks5c.UsernamePassword, header []byte) *ClientConn {
|
||||
s := &ClientConn{
|
||||
conn: conn,
|
||||
network: network,
|
||||
timeout: timeout,
|
||||
}
|
||||
if auth != nil {
|
||||
s.user = auth.Username
|
||||
s.password = auth.Password
|
||||
}
|
||||
if header != nil && len(header) > 0 {
|
||||
s.header = header
|
||||
}
|
||||
if network == "udp" && target == "" {
|
||||
target = "0.0.0.0:1"
|
||||
}
|
||||
s.addr = target
|
||||
return s
|
||||
}
|
||||
|
||||
// connect takes an existing connection to a socks5 proxy server,
|
||||
// and commands the server to extend that connection to target,
|
||||
// which must be a canonical address with a host and port.
|
||||
func (s *ClientConn) _Handshake() error {
|
||||
host, portStr, err := net.SplitHostPort(s.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return errors.New("proxy: failed to parse port number: " + portStr)
|
||||
}
|
||||
if port < 1 || port > 0xffff {
|
||||
return errors.New("proxy: port number out of range: " + portStr)
|
||||
}
|
||||
|
||||
if err := s.auth(host); err != nil {
|
||||
return err
|
||||
}
|
||||
buf := []byte{}
|
||||
if s.network == "tcp" {
|
||||
buf = append(buf, socks5c.VERSION_V5, socks5c.CMD_CONNECT, 0 /* reserved */)
|
||||
|
||||
} else {
|
||||
buf = append(buf, socks5c.VERSION_V5, socks5c.CMD_ASSOCIATE, 0 /* reserved */)
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
buf = append(buf, socks5c.ATYP_IPV4)
|
||||
ip = ip4
|
||||
} else {
|
||||
buf = append(buf, socks5c.ATYP_IPV6)
|
||||
}
|
||||
buf = append(buf, ip...)
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return errors.New("proxy: destination host name too long: " + host)
|
||||
}
|
||||
buf = append(buf, socks5c.ATYP_DOMAIN)
|
||||
buf = append(buf, byte(len(host)))
|
||||
buf = append(buf, host...)
|
||||
}
|
||||
buf = append(buf, byte(port>>8), byte(port))
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := (*s.conn).Write(buf); err != nil {
|
||||
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := io.ReadFull((*s.conn), buf[:4]); err != nil {
|
||||
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
failure := "unknown error"
|
||||
if int(buf[1]) < len(socks5c.Socks5Errors) {
|
||||
failure = socks5c.Socks5Errors[buf[1]]
|
||||
}
|
||||
|
||||
if len(failure) > 0 {
|
||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
|
||||
}
|
||||
|
||||
bytesToDiscard := 0
|
||||
switch buf[3] {
|
||||
case socks5c.ATYP_IPV4:
|
||||
bytesToDiscard = net.IPv4len
|
||||
case socks5c.ATYP_IPV6:
|
||||
bytesToDiscard = net.IPv6len
|
||||
case socks5c.ATYP_DOMAIN:
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
_, err := io.ReadFull((*s.conn), buf[:1])
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
bytesToDiscard = int(buf[0])
|
||||
default:
|
||||
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
|
||||
}
|
||||
|
||||
if cap(buf) < bytesToDiscard {
|
||||
buf = make([]byte, bytesToDiscard)
|
||||
} else {
|
||||
buf = buf[:bytesToDiscard]
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := io.ReadFull((*s.conn), buf); err != nil {
|
||||
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
var ip net.IP
|
||||
ip = buf
|
||||
ipStr := ""
|
||||
if bytesToDiscard == net.IPv4len || bytesToDiscard == net.IPv6len {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipStr = ipv4.String()
|
||||
} else {
|
||||
ipStr = ip.To16().String()
|
||||
}
|
||||
}
|
||||
//log.Printf("%v", ipStr)
|
||||
// Also need to discard the port number
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil {
|
||||
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
p := binary.BigEndian.Uint16([]byte{buf[0], buf[1]})
|
||||
//log.Printf("%v", p)
|
||||
s.udpAddr = net.JoinHostPort(ipStr, fmt.Sprintf("%d", p))
|
||||
//log.Printf("%v", s.udpAddr)
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
return nil
|
||||
}
|
||||
func (s *ClientConn) SendUDP(data []byte, addr string) (respData []byte, err error) {
|
||||
|
||||
c, err := net.DialTimeout("udp", s.udpAddr, s.timeout)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn := c.(*net.UDPConn)
|
||||
|
||||
p := socks5c.NewPacketUDP()
|
||||
p.Build(addr, data)
|
||||
conn.SetDeadline(time.Now().Add(s.timeout))
|
||||
conn.Write(p.Bytes())
|
||||
conn.SetDeadline(time.Time{})
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
conn.SetDeadline(time.Now().Add(s.timeout))
|
||||
n, _, err := conn.ReadFrom(buf)
|
||||
conn.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
respData = buf[:n]
|
||||
return
|
||||
}
|
||||
func (s *ClientConn) auth(host string) error {
|
||||
|
||||
// the size here is just an estimate
|
||||
buf := make([]byte, 0, 6+len(host))
|
||||
|
||||
buf = append(buf, socks5c.VERSION_V5)
|
||||
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
|
||||
buf = append(buf, 2 /* num auth methods */, socks5c.Method_NO_AUTH, socks5c.Method_USER_PASS)
|
||||
} else {
|
||||
buf = append(buf, 1 /* num auth methods */, socks5c.Method_NO_AUTH)
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := (*s.conn).Write(buf); err != nil {
|
||||
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil {
|
||||
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
|
||||
if buf[0] != 5 {
|
||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
|
||||
}
|
||||
if buf[1] == 0xff {
|
||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
|
||||
}
|
||||
|
||||
// See RFC 1929
|
||||
if buf[1] == socks5c.Method_USER_PASS {
|
||||
buf = buf[:0]
|
||||
buf = append(buf, 1 /* password protocol version */)
|
||||
buf = append(buf, uint8(len(s.user)))
|
||||
buf = append(buf, s.user...)
|
||||
buf = append(buf, uint8(len(s.password)))
|
||||
buf = append(buf, s.password...)
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := (*s.conn).Write(buf); err != nil {
|
||||
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
(*s.conn).SetDeadline(time.Now().Add(s.timeout))
|
||||
if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil {
|
||||
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||
}
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
if buf[1] != 0 {
|
||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
79
core/proxy/client/tests/proxy_test.go
Normal file
79
core/proxy/client/tests/proxy_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
proxyclient "github.com/snail007/goproxy/core/proxy/client"
|
||||
sdk "github.com/snail007/goproxy/sdk/android-ios"
|
||||
)
|
||||
|
||||
func TestSocks5(t *testing.T) {
|
||||
estr := sdk.Start("s1", "socks -p :8185 --log test.log")
|
||||
if estr != "" {
|
||||
t.Fatal(estr)
|
||||
}
|
||||
p, e := proxyclient.SOCKS5(time.Second, nil)
|
||||
if e != nil {
|
||||
t.Error(e)
|
||||
} else {
|
||||
c, e := net.Dial("tcp", "127.0.0.1:8185")
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
e = p.DialConn(&c, "tcp", "www.baidu.com:80")
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
_, e = c.Write([]byte("Get / http/1.1\r\nHost: www.baidu.com\r\n"))
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
b, e := ioutil.ReadAll(c)
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
if !strings.HasPrefix(string(b), "HTTP") {
|
||||
t.Fatalf("request baidu fail:%s", string(b))
|
||||
}
|
||||
}
|
||||
sdk.Stop("s1")
|
||||
os.Remove("test.log")
|
||||
}
|
||||
|
||||
func TestSocks5Auth(t *testing.T) {
|
||||
estr := sdk.Start("s1", "socks -p :8185 -a u:p --log test.log")
|
||||
if estr != "" {
|
||||
t.Fatal(estr)
|
||||
}
|
||||
p, e := proxyclient.SOCKS5(time.Second, &proxyclient.Auth{User: "u", Password: "p"})
|
||||
if e != nil {
|
||||
t.Error(e)
|
||||
} else {
|
||||
c, e := net.Dial("tcp", "127.0.0.1:8185")
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
e = p.DialConn(&c, "tcp", "www.baidu.com:80")
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
_, e = c.Write([]byte("Get / http/1.1\r\nHost: www.baidu.com\r\n"))
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
b, e := ioutil.ReadAll(c)
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
if !strings.HasPrefix(string(b), "HTTP") {
|
||||
t.Fatalf("request baidu fail:%s", string(b))
|
||||
}
|
||||
}
|
||||
sdk.Stop("s1")
|
||||
os.Remove("test.log")
|
||||
}
|
||||
373
core/proxy/server/socks5/server.go
Normal file
373
core/proxy/server/socks5/server.go
Normal file
@ -0,0 +1,373 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
socks5c "github.com/snail007/goproxy/core/proxy/common/socks5"
|
||||
)
|
||||
|
||||
type BasicAuther interface {
|
||||
CheckUserPass(username, password, fromIP, ToTarget string) bool
|
||||
}
|
||||
type Request struct {
|
||||
ver uint8
|
||||
cmd uint8
|
||||
reserve uint8
|
||||
addressType uint8
|
||||
dstAddr string
|
||||
dstPort string
|
||||
dstHost string
|
||||
bytes []byte
|
||||
rw io.ReadWriter
|
||||
}
|
||||
|
||||
func NewRequest(rw io.ReadWriter, header ...[]byte) (req Request, err interface{}) {
|
||||
var b = make([]byte, 1024)
|
||||
var n int
|
||||
req = Request{rw: rw}
|
||||
if header != nil && len(header) == 1 && len(header[0]) > 1 {
|
||||
b = header[0]
|
||||
n = len(header[0])
|
||||
} else {
|
||||
n, err = rw.Read(b[:])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("read req data fail,ERR: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
req.ver = uint8(b[0])
|
||||
req.cmd = uint8(b[1])
|
||||
req.reserve = uint8(b[2])
|
||||
req.addressType = uint8(b[3])
|
||||
if b[0] != 0x5 {
|
||||
err = fmt.Errorf("sosck version supported")
|
||||
req.TCPReply(socks5c.REP_REQ_FAIL)
|
||||
return
|
||||
}
|
||||
switch b[3] {
|
||||
case 0x01: //IP V4
|
||||
req.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String()
|
||||
case 0x03: //域名
|
||||
req.dstHost = string(b[5 : n-2]) //b[4]表示域名的长度
|
||||
case 0x04: //IP V6
|
||||
req.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String()
|
||||
}
|
||||
req.dstPort = strconv.Itoa(int(b[n-2])<<8 | int(b[n-1]))
|
||||
req.dstAddr = net.JoinHostPort(req.dstHost, req.dstPort)
|
||||
req.bytes = b[:n]
|
||||
return
|
||||
}
|
||||
func (s *Request) Bytes() []byte {
|
||||
return s.bytes
|
||||
}
|
||||
func (s *Request) Addr() string {
|
||||
return s.dstAddr
|
||||
}
|
||||
func (s *Request) Host() string {
|
||||
return s.dstHost
|
||||
}
|
||||
func (s *Request) Port() string {
|
||||
return s.dstPort
|
||||
}
|
||||
func (s *Request) AType() uint8 {
|
||||
return s.addressType
|
||||
}
|
||||
func (s *Request) CMD() uint8 {
|
||||
return s.cmd
|
||||
}
|
||||
|
||||
func (s *Request) TCPReply(rep uint8) (err error) {
|
||||
_, err = s.rw.Write(s.NewReply(rep, "0.0.0.0:0"))
|
||||
return
|
||||
}
|
||||
func (s *Request) UDPReply(rep uint8, addr string) (err error) {
|
||||
_, err = s.rw.Write(s.NewReply(rep, addr))
|
||||
return
|
||||
}
|
||||
func (s *Request) NewReply(rep uint8, addr string) []byte {
|
||||
var response bytes.Buffer
|
||||
host, port, _ := net.SplitHostPort(addr)
|
||||
ip := net.ParseIP(host)
|
||||
ipb := ip.To4()
|
||||
atyp := socks5c.ATYP_IPV4
|
||||
ipv6 := ip.To16()
|
||||
zeroiIPv6 := fmt.Sprintf("%d%d%d%d%d%d%d%d%d%d%d%d",
|
||||
ipv6[0], ipv6[1], ipv6[2], ipv6[3],
|
||||
ipv6[4], ipv6[5], ipv6[6], ipv6[7],
|
||||
ipv6[8], ipv6[9], ipv6[10], ipv6[11],
|
||||
)
|
||||
if ipb == nil && ipv6 != nil && "0000000000255255" != zeroiIPv6 {
|
||||
atyp = socks5c.ATYP_IPV6
|
||||
ipb = ip.To16()
|
||||
}
|
||||
porti, _ := strconv.Atoi(port)
|
||||
portb := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(portb, uint16(porti))
|
||||
// log.Printf("atyp : %v", atyp)
|
||||
// log.Printf("ip : %v", []byte(ip))
|
||||
response.WriteByte(socks5c.VERSION_V5)
|
||||
response.WriteByte(rep)
|
||||
response.WriteByte(socks5c.RSV)
|
||||
response.WriteByte(atyp)
|
||||
response.Write(ipb)
|
||||
response.Write(portb)
|
||||
return response.Bytes()
|
||||
}
|
||||
|
||||
type MethodsRequest struct {
|
||||
ver uint8
|
||||
methodsCount uint8
|
||||
methods []uint8
|
||||
bytes []byte
|
||||
rw *io.ReadWriter
|
||||
}
|
||||
|
||||
func NewMethodsRequest(r io.ReadWriter, header ...[]byte) (s MethodsRequest, err interface{}) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
err = recover()
|
||||
}
|
||||
}()
|
||||
s = MethodsRequest{}
|
||||
s.rw = &r
|
||||
var buf = make([]byte, 300)
|
||||
var n int
|
||||
if header != nil && len(header) == 1 && len(header[0]) > 1 {
|
||||
buf = header[0]
|
||||
n = len(header[0])
|
||||
} else {
|
||||
n, err = r.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if buf[0] != 0x05 {
|
||||
err = fmt.Errorf("socks version not supported")
|
||||
return
|
||||
}
|
||||
if n != int(buf[1])+int(2) {
|
||||
err = fmt.Errorf("socks methods data length error")
|
||||
return
|
||||
}
|
||||
s.ver = buf[0]
|
||||
s.methodsCount = buf[1]
|
||||
s.methods = buf[2:n]
|
||||
s.bytes = buf[:n]
|
||||
return
|
||||
}
|
||||
func (s *MethodsRequest) Version() uint8 {
|
||||
return s.ver
|
||||
}
|
||||
func (s *MethodsRequest) MethodsCount() uint8 {
|
||||
return s.methodsCount
|
||||
}
|
||||
func (s *MethodsRequest) Methods() []uint8 {
|
||||
return s.methods
|
||||
}
|
||||
func (s *MethodsRequest) Select(method uint8) bool {
|
||||
for _, m := range s.methods {
|
||||
if m == method {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (s *MethodsRequest) Reply(method uint8) (err error) {
|
||||
_, err = (*s.rw).Write([]byte{byte(socks5c.VERSION_V5), byte(method)})
|
||||
return
|
||||
}
|
||||
func (s *MethodsRequest) Bytes() []byte {
|
||||
return s.bytes
|
||||
}
|
||||
|
||||
type ServerConn struct {
|
||||
target string
|
||||
user string
|
||||
password string
|
||||
conn *net.Conn
|
||||
timeout time.Duration
|
||||
auth *BasicAuther
|
||||
header []byte
|
||||
ver uint8
|
||||
//method
|
||||
methodsCount uint8
|
||||
methods []uint8
|
||||
method uint8
|
||||
//request
|
||||
cmd uint8
|
||||
reserve uint8
|
||||
addressType uint8
|
||||
dstAddr string
|
||||
dstPort string
|
||||
dstHost string
|
||||
udpAddress string
|
||||
}
|
||||
|
||||
func NewServerConn(conn *net.Conn, timeout time.Duration, auth *BasicAuther, udpAddress string, header []byte) *ServerConn {
|
||||
if udpAddress == "" {
|
||||
udpAddress = "0.0.0.0:16666"
|
||||
}
|
||||
s := &ServerConn{
|
||||
conn: conn,
|
||||
timeout: timeout,
|
||||
auth: auth,
|
||||
header: header,
|
||||
ver: socks5c.VERSION_V5,
|
||||
udpAddress: udpAddress,
|
||||
}
|
||||
return s
|
||||
|
||||
}
|
||||
func (s *ServerConn) Close() {
|
||||
(*s.conn).Close()
|
||||
}
|
||||
func (s *ServerConn) AuthData() socks5c.UsernamePassword {
|
||||
return socks5c.UsernamePassword{s.user, s.password}
|
||||
}
|
||||
func (s *ServerConn) Method() uint8 {
|
||||
return s.method
|
||||
}
|
||||
func (s *ServerConn) Target() string {
|
||||
return s.target
|
||||
}
|
||||
func (s *ServerConn) Handshake() (err error) {
|
||||
remoteAddr := (*s.conn).RemoteAddr()
|
||||
//协商开始
|
||||
//method select request
|
||||
var methodReq MethodsRequest
|
||||
(*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
|
||||
methodReq, e := NewMethodsRequest((*s.conn), s.header)
|
||||
(*s.conn).SetReadDeadline(time.Time{})
|
||||
if e != nil {
|
||||
(*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
methodReq.Reply(socks5c.Method_NONE_ACCEPTABLE)
|
||||
(*s.conn).SetReadDeadline(time.Time{})
|
||||
err = fmt.Errorf("new methods request fail,ERR: %s", e)
|
||||
return
|
||||
}
|
||||
//log.Printf("%v,s.auth == %v && methodReq.Select(Method_NO_AUTH) %v", methodReq.methods, s.auth, methodReq.Select(Method_NO_AUTH))
|
||||
if s.auth == nil && methodReq.Select(socks5c.Method_NO_AUTH) && !methodReq.Select(socks5c.Method_USER_PASS) {
|
||||
// if !methodReq.Select(Method_NO_AUTH) {
|
||||
// (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
// methodReq.Reply(Method_NONE_ACCEPTABLE)
|
||||
// (*s.conn).SetReadDeadline(time.Time{})
|
||||
// err = fmt.Errorf("none method found : Method_NO_AUTH")
|
||||
// return
|
||||
// }
|
||||
s.method = socks5c.Method_NO_AUTH
|
||||
//method select reply
|
||||
(*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
err = methodReq.Reply(socks5c.Method_NO_AUTH)
|
||||
(*s.conn).SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
err = fmt.Errorf("reply answer data fail,ERR: %s", err)
|
||||
return
|
||||
}
|
||||
// err = fmt.Errorf("% x", methodReq.Bytes())
|
||||
} else {
|
||||
//auth
|
||||
if !methodReq.Select(socks5c.Method_USER_PASS) {
|
||||
(*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
methodReq.Reply(socks5c.Method_NONE_ACCEPTABLE)
|
||||
(*s.conn).SetReadDeadline(time.Time{})
|
||||
err = fmt.Errorf("none method found : Method_USER_PASS")
|
||||
return
|
||||
}
|
||||
s.method = socks5c.Method_USER_PASS
|
||||
//method reply need auth
|
||||
(*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
err = methodReq.Reply(socks5c.Method_USER_PASS)
|
||||
(*s.conn).SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
err = fmt.Errorf("reply answer data fail,ERR: %s", err)
|
||||
return
|
||||
}
|
||||
//read auth
|
||||
buf := make([]byte, 500)
|
||||
var n int
|
||||
(*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
n, err = (*s.conn).Read(buf)
|
||||
(*s.conn).SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
err = fmt.Errorf("read auth info fail,ERR: %s", err)
|
||||
return
|
||||
}
|
||||
r := buf[:n]
|
||||
s.user = string(r[2 : r[1]+2])
|
||||
s.password = string(r[2+r[1]+1:])
|
||||
//err = fmt.Errorf("user:%s,pass:%s", user, pass)
|
||||
//auth
|
||||
_addr := strings.Split(remoteAddr.String(), ":")
|
||||
if s.auth == nil || (*s.auth).CheckUserPass(s.user, s.password, _addr[0], "") {
|
||||
(*s.conn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(s.timeout)))
|
||||
_, err = (*s.conn).Write([]byte{0x01, 0x00})
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
err = fmt.Errorf("answer auth success to %s fail,ERR: %s", remoteAddr, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
(*s.conn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(s.timeout)))
|
||||
_, err = (*s.conn).Write([]byte{0x01, 0x01})
|
||||
(*s.conn).SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
err = fmt.Errorf("answer auth fail to %s fail,ERR: %s", remoteAddr, err)
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("auth fail from %s", remoteAddr)
|
||||
return
|
||||
}
|
||||
}
|
||||
//request detail
|
||||
(*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
|
||||
request, e := NewRequest(*s.conn)
|
||||
(*s.conn).SetReadDeadline(time.Time{})
|
||||
if e != nil {
|
||||
err = fmt.Errorf("read request data fail,ERR: %s", e)
|
||||
return
|
||||
}
|
||||
//协商结束
|
||||
|
||||
switch request.CMD() {
|
||||
case socks5c.CMD_BIND:
|
||||
err = request.TCPReply(socks5c.REP_UNKNOWN)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("TCPReply REP_UNKNOWN to %s fail,ERR: %s", remoteAddr, err)
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("cmd bind not supported, form: %s", remoteAddr)
|
||||
return
|
||||
case socks5c.CMD_CONNECT:
|
||||
err = request.TCPReply(socks5c.REP_SUCCESS)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("TCPReply REP_SUCCESS to %s fail,ERR: %s", remoteAddr, err)
|
||||
return
|
||||
}
|
||||
case socks5c.CMD_ASSOCIATE:
|
||||
err = request.UDPReply(socks5c.REP_SUCCESS, s.udpAddress)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("UDPReply REP_SUCCESS to %s fail,ERR: %s", remoteAddr, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
//fill socks info
|
||||
s.target = request.Addr()
|
||||
s.methodsCount = methodReq.MethodsCount()
|
||||
s.methods = methodReq.Methods()
|
||||
s.cmd = request.CMD()
|
||||
s.reserve = request.reserve
|
||||
s.addressType = request.addressType
|
||||
s.dstAddr = request.dstAddr
|
||||
s.dstHost = request.dstHost
|
||||
s.dstPort = request.dstPort
|
||||
return
|
||||
}
|
||||
35
core/tproxy/README.md
Normal file
35
core/tproxy/README.md
Normal file
@ -0,0 +1,35 @@
|
||||
# 透传用户IP手册
|
||||
|
||||
说明:
|
||||
|
||||
通过Linux的TPROXY功能,可以实现源站服务程序可以看见客户端真实IP,实现该功能需要linux操作系统和程序都要满足一定的条件.
|
||||
|
||||
环境要求:
|
||||
|
||||
源站必须是运行在Linux上面的服务程序,同时Linux需要满足下面条件:
|
||||
|
||||
1.Linux内核版本 >= 2.6.28
|
||||
|
||||
2.判断系统是否支持TPROXY,执行:
|
||||
|
||||
grep TPROXY /boot/config-`uname -r`
|
||||
|
||||
如果输出有下面的结果说明支持.
|
||||
|
||||
CONFIG_NETFILTER_XT_TARGET_TPROXY=m
|
||||
|
||||
部署步骤:
|
||||
|
||||
1.在源站的linux系统里面每次开机启动都要用root权限执行tproxy环境设置脚本:tproxy_setup.sh
|
||||
|
||||
2.在源站的linux系统里面使用root权限执行代理proxy
|
||||
|
||||
参数 -tproxy 是开启代理的tproxy功能.
|
||||
|
||||
./proxy -tproxy
|
||||
|
||||
2.源站的程序监听的地址IP需要使用:127.0.1.1
|
||||
|
||||
比如源站以前监听的地址是: 0.0.0.0:8800 , 现在需要修改为:127.0.1.1:8800
|
||||
|
||||
3.转发规则里面源站地址必须是对应的,比如上面的:127.0.1.1:8800
|
||||
249
core/tproxy/tproxy.go
Normal file
249
core/tproxy/tproxy.go
Normal file
@ -0,0 +1,249 @@
|
||||
// Package tproxy provides the TCPDial and TCPListen tproxy equivalent of the
|
||||
// net package Dial and Listen with tproxy support for linux ONLY.
|
||||
package tproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const big = 0xFFFFFF
|
||||
const IP_ORIGADDRS = 20
|
||||
|
||||
// Debug outs the library in Debug mode
|
||||
var Debug = false
|
||||
|
||||
func ipToSocksAddr(family int, ip net.IP, port int, zone string) (unix.Sockaddr, error) {
|
||||
switch family {
|
||||
case unix.AF_INET:
|
||||
if len(ip) == 0 {
|
||||
ip = net.IPv4zero
|
||||
}
|
||||
if ip = ip.To4(); ip == nil {
|
||||
return nil, net.InvalidAddrError("non-IPv4 address")
|
||||
}
|
||||
sa := new(unix.SockaddrInet4)
|
||||
for i := 0; i < net.IPv4len; i++ {
|
||||
sa.Addr[i] = ip[i]
|
||||
}
|
||||
sa.Port = port
|
||||
return sa, nil
|
||||
case unix.AF_INET6:
|
||||
if len(ip) == 0 {
|
||||
ip = net.IPv6zero
|
||||
}
|
||||
// IPv4 callers use 0.0.0.0 to mean "announce on any available address".
|
||||
// In IPv6 mode, Linux treats that as meaning "announce on 0.0.0.0",
|
||||
// which it refuses to do. Rewrite to the IPv6 unspecified address.
|
||||
if ip.Equal(net.IPv4zero) {
|
||||
ip = net.IPv6zero
|
||||
}
|
||||
if ip = ip.To16(); ip == nil {
|
||||
return nil, net.InvalidAddrError("non-IPv6 address")
|
||||
}
|
||||
sa := new(unix.SockaddrInet6)
|
||||
for i := 0; i < net.IPv6len; i++ {
|
||||
sa.Addr[i] = ip[i]
|
||||
}
|
||||
sa.Port = port
|
||||
sa.ZoneId = uint32(zoneToInt(zone))
|
||||
return sa, nil
|
||||
}
|
||||
return nil, net.InvalidAddrError("unexpected socket family")
|
||||
}
|
||||
|
||||
func zoneToInt(zone string) int {
|
||||
if zone == "" {
|
||||
return 0
|
||||
}
|
||||
if ifi, err := net.InterfaceByName(zone); err == nil {
|
||||
return ifi.Index
|
||||
}
|
||||
n, _, _ := dtoi(zone, 0)
|
||||
return n
|
||||
}
|
||||
|
||||
func dtoi(s string, i0 int) (n int, i int, ok bool) {
|
||||
n = 0
|
||||
for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
|
||||
n = n*10 + int(s[i]-'0')
|
||||
if n >= big {
|
||||
return 0, i, false
|
||||
}
|
||||
}
|
||||
if i == i0 {
|
||||
return 0, i, false
|
||||
}
|
||||
return n, i, true
|
||||
}
|
||||
|
||||
// IPTcpAddrToUnixSocksAddr ---
|
||||
func IPTcpAddrToUnixSocksAddr(addr string) (sa unix.Sockaddr, err error) {
|
||||
if Debug {
|
||||
fmt.Println("DEBUG: IPTcpAddrToUnixSocksAddr recieved address:", addr)
|
||||
}
|
||||
addressNet := "tcp6"
|
||||
if addr[0] != '[' {
|
||||
addressNet = "tcp4"
|
||||
}
|
||||
tcpAddr, err := net.ResolveTCPAddr(addressNet, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ipToSocksAddr(ipType(addr), tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone)
|
||||
}
|
||||
|
||||
// IPv6UdpAddrToUnixSocksAddr ---
|
||||
func IPv6UdpAddrToUnixSocksAddr(addr string) (sa unix.Sockaddr, err error) {
|
||||
tcpAddr, err := net.ResolveTCPAddr("udp6", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ipToSocksAddr(unix.AF_INET6, tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone)
|
||||
}
|
||||
|
||||
// TCPListen is listening for incoming IP packets which are being intercepted.
|
||||
// In conflict to regular Listen mehtod the socket destination and source addresses
|
||||
// are of the intercepted connection.
|
||||
// Else then that it works exactly like net package net.Listen.
|
||||
func TCPListen(listenAddr string) (listener net.Listener, err error) {
|
||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer unix.Close(s)
|
||||
err = unix.SetsockoptInt(s, unix.SOL_IP, unix.IP_TRANSPARENT, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sa, err := IPTcpAddrToUnixSocksAddr(listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = unix.Bind(s, sa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = unix.Listen(s, unix.SOMAXCONN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f := os.NewFile(uintptr(s), "TProxy")
|
||||
defer f.Close()
|
||||
return net.FileListener(f)
|
||||
}
|
||||
func ipType(localAddr string) int {
|
||||
host, _, _ := net.SplitHostPort(localAddr)
|
||||
if host != "" {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil || ip.To4() != nil {
|
||||
return unix.AF_INET
|
||||
}
|
||||
return unix.AF_INET6
|
||||
}
|
||||
return unix.AF_INET
|
||||
}
|
||||
|
||||
// TCPDial is a special tcp connection which binds a non local address as the source.
|
||||
// Except then the option to bind to a specific local address which the machine doesn't posses
|
||||
// it is exactly like any other net.Conn connection.
|
||||
// It is advised to use port numbered 0 in the localAddr and leave the kernel to choose which
|
||||
// Local port to use in order to avoid errors and binding conflicts.
|
||||
func TCPDial(localAddr, remoteAddr string, timeout time.Duration) (conn net.Conn, err error) {
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
if Debug {
|
||||
fmt.Println("TCPDial from:", localAddr, "to:", remoteAddr)
|
||||
}
|
||||
s, err := unix.Socket(ipType(localAddr), unix.SOCK_STREAM, 0)
|
||||
|
||||
//In a case there was a need for a non-blocking socket an example
|
||||
//s, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM |unix.SOCK_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
defer unix.Close(s)
|
||||
err = unix.SetsockoptInt(s, unix.SOL_IP, unix.IP_TRANSPARENT, 1)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Println("ERROR setting the socket in IP_TRANSPARENT mode", err)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
err = unix.SetsockoptInt(s, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Println("ERROR setting the socket in unix.SO_REUSEADDR mode", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rhost, _, err := net.SplitHostPort(localAddr)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
// fmt.Fprintln(os.Stderr, err)
|
||||
fmt.Println("ERROR", err, "running net.SplitHostPort on address:", localAddr)
|
||||
}
|
||||
}
|
||||
|
||||
sa, err := IPTcpAddrToUnixSocksAddr(rhost + ":0")
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Println("ERROR creating a hostaddres for the socker with IPTcpAddrToUnixSocksAddr", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remoteSocket, err := IPTcpAddrToUnixSocksAddr(remoteAddr)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Println("ERROR creating a remoteSocket for the socker with IPTcpAddrToUnixSocksAddr on the remote addres", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.Bind(s, sa)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errChn := make(chan error, 1)
|
||||
func() {
|
||||
err = unix.Connect(s, remoteSocket)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Println("ERROR Connecting from", s, "to:", remoteSocket, "ERROR:", err)
|
||||
}
|
||||
}
|
||||
errChn <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-errChn:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case <-timer.C:
|
||||
return nil, fmt.Errorf("ERROR connect to %s timeout", remoteAddr)
|
||||
}
|
||||
f := os.NewFile(uintptr(s), "TProxyTCPClient")
|
||||
client, err := net.FileConn(f)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Println("ERROR os.NewFile", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if Debug {
|
||||
fmt.Println("FINISHED Creating net.coo from:", client.LocalAddr().String(), "to:", client.RemoteAddr().String())
|
||||
}
|
||||
return client, err
|
||||
}
|
||||
30
core/tproxy/tproxy_setup.sh
Normal file
30
core/tproxy/tproxy_setup.sh
Normal file
@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
SOURCE_BIND_IP="127.0.1.1"
|
||||
|
||||
echo 0 > /proc/sys/net/ipv4/conf/lo/rp_filter
|
||||
echo 2 > /proc/sys/net/ipv4/conf/default/rp_filter
|
||||
echo 2 > /proc/sys/net/ipv4/conf/all/rp_filter
|
||||
echo 1 > /proc/sys/net/ipv4/conf/all/send_redirects
|
||||
echo 1 > /proc/sys/net/ipv4/conf/all/forwarding
|
||||
echo 1 > /proc/sys/net/ipv4/ip_forward
|
||||
|
||||
# 本地的话,貌似这段不需要
|
||||
# iptables -t mangle -N DIVERT >/dev/null 2>&1
|
||||
# iptables -t mangle -F DIVERT
|
||||
# iptables -t mangle -D PREROUTING -p tcp -m socket -j DIVERT >/dev/null 2>&1
|
||||
# iptables -t mangle -A PREROUTING -p tcp -m socket -j DIVERT
|
||||
# iptables -t mangle -A DIVERT -j MARK --set-mark 1
|
||||
# iptables -t mangle -A DIVERT -j ACCEPT
|
||||
|
||||
ip rule del fwmark 1 lookup 100
|
||||
ip rule add fwmark 1 lookup 100
|
||||
ip route del local 0.0.0.0/0 dev lo table 100
|
||||
ip route add local 0.0.0.0/0 dev lo table 100
|
||||
|
||||
ip rule del from ${SOURCE_BIND_IP} table 101
|
||||
ip rule add from ${SOURCE_BIND_IP} table 101
|
||||
ip route del default via 127.0.0.1 dev lo table 101
|
||||
ip route add default via 127.0.0.1 dev lo table 101
|
||||
|
||||
ip route flush cache
|
||||
ip ro flush cache
|
||||
Reference in New Issue
Block a user