diff --git a/services/sps/ssudp.go b/services/sps/ssudp.go new file mode 100644 index 0000000..7f75cfa --- /dev/null +++ b/services/sps/ssudp.go @@ -0,0 +1,161 @@ +package sps + +import ( + "bytes" + "fmt" + "net" + "runtime/debug" + "time" + + "bitbucket.org/snail/proxy/utils" + goaes "bitbucket.org/snail/proxy/utils/aes" + "bitbucket.org/snail/proxy/utils/socks" +) + +func (s *SPS) RunSSUDP(addr string) (err error) { + a, _ := net.ResolveUDPAddr("udp", addr) + listener, err := net.ListenUDP("udp", a) + if err != nil { + s.log.Printf("ss udp bind error %s", err) + return + } + s.log.Printf("ss udp on %s", listener.LocalAddr()) + s.udpRelatedPacketConns.Set(addr, listener) + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("udp local->out io copy crashed:\n%s\n%s", e, string(debug.Stack())) + } + }() + for { + buf := utils.LeakyBuffer.Get() + defer utils.LeakyBuffer.Put(buf) + n, srcAddr, err := listener.ReadFrom(buf) + if err != nil { + s.log.Printf("read from client error %s", err) + if utils.IsNetClosedErr(err) { + return + } + continue + } + var ( + inconnRemoteAddr = srcAddr.String() + outUDPConn *net.UDPConn + outconn net.Conn + outconnLocalAddr string + destAddr *net.UDPAddr + clean = func(msg, err string) { + raddr := "" + if outUDPConn != nil { + raddr = outUDPConn.RemoteAddr().String() + outUDPConn.Close() + } + if msg != "" { + if raddr != "" { + s.log.Printf("%s , %s , %s -> %s", msg, err, inconnRemoteAddr, raddr) + } else { + s.log.Printf("%s , %s , from : %s", msg, err, inconnRemoteAddr) + } + } + s.userConns.Remove(inconnRemoteAddr) + if outconn != nil { + outconn.Close() + } + if outconnLocalAddr != "" { + s.userConns.Remove(outconnLocalAddr) + } + } + ) + defer clean("", "") + + raw := new(bytes.Buffer) + raw.Write([]byte{0x00, 0x00, 0x00}) + raw.Write(s.localCipher.Decrypt(buf[:n])) + socksPacket := socks.NewPacketUDP() + err = socksPacket.Parse(raw.Bytes()) + raw = nil + if err != nil { + s.log.Printf("udp parse error %s", err) + return + } + + if v, ok := s.udpRelatedPacketConns.Get(inconnRemoteAddr); !ok { + //socks client + lbAddr := s.lb.Select(inconnRemoteAddr, *s.cfg.LoadBalanceOnlyHA) + outconn, err := s.GetParentConn(lbAddr) + if err != nil { + clean("connnect fail", fmt.Sprintf("%s", err)) + return + } + + client, err := s.HandshakeSocksParent(&outconn, "udp", socksPacket.Addr(), socks.Auth{}, true) + if err != nil { + clean("handshake fail", fmt.Sprintf("%s", err)) + return + } + + outconnLocalAddr = outconn.LocalAddr().String() + s.userConns.Set(outconnLocalAddr, &outconn) + go func() { + defer func() { + if e := recover(); e != nil { + s.log.Printf("udp related parent tcp conn read crashed:\n%s\n%s", e, string(debug.Stack())) + } + }() + buf := make([]byte, 1) + outconn.SetReadDeadline(time.Time{}) + if _, err := outconn.Read(buf); err != nil { + clean("udp parent tcp conn disconnected", fmt.Sprintf("%s", err)) + } + }() + destAddr, _ = net.ResolveUDPAddr("udp", client.UDPAddr) + localZeroAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + outUDPConn, err = net.DialUDP("udp", localZeroAddr, destAddr) + if err != nil { + s.log.Printf("create out udp conn fail , %s , from : %s", err, srcAddr) + return + } + s.udpRelatedPacketConns.Set(srcAddr.String(), outUDPConn) + utils.UDPCopy(listener, outUDPConn, srcAddr, time.Second*5, func(data []byte) []byte { + //forward to local + var v []byte + //convert parent data to raw + if len(s.udpParentKey) > 0 { + v, err = goaes.Decrypt(s.udpParentKey, data) + if err != nil { + s.log.Printf("udp outconn parse packet fail, %s", err.Error()) + return []byte{} + } + } else { + v = data + } + return s.localCipher.Encrypt(v[3:]) + }, func(err interface{}) { + s.udpRelatedPacketConns.Remove(srcAddr.String()) + if err != nil { + s.log.Printf("udp out->local io copy crashed:\n%s\n%s", err, string(debug.Stack())) + } + }) + } else { + outUDPConn = v.(*net.UDPConn) + } + //forward to parent + //p is raw, now convert it to parent + var v []byte + if len(s.udpParentKey) > 0 { + v, _ = goaes.Encrypt(s.udpParentKey, socksPacket.Bytes()) + } else { + v = socksPacket.Bytes() + } + _, err = outUDPConn.Write(v) + socksPacket = socks.PacketUDP{} + if err != nil { + if utils.IsNetClosedErr(err) { + return + } + s.log.Printf("send out udp data fail , %s , from : %s", err, srcAddr) + } + } + }() + return +} diff --git a/utils/datasize/datasize.go b/utils/datasize/datasize.go new file mode 100644 index 0000000..b6beef0 --- /dev/null +++ b/utils/datasize/datasize.go @@ -0,0 +1,235 @@ +package datasize + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +type ByteSize uint64 + +const ( + B ByteSize = 1 + KB = B << 10 + MB = KB << 10 + GB = MB << 10 + TB = GB << 10 + PB = TB << 10 + EB = PB << 10 + + fnUnmarshalText string = "UnmarshalText" + maxUint64 uint64 = (1 << 64) - 1 + cutoff uint64 = maxUint64 / 10 +) + +var ErrBits = errors.New("unit with capital unit prefix and lower case unit (b) - bits, not bytes ") + +var defaultDatasize ByteSize + +func Parse(s string) (bytes uint64, err error) { + err = defaultDatasize.UnmarshalText([]byte(s)) + if err != nil { + return + } + bytes = defaultDatasize.Bytes() + return +} +func HumanSize(bytes uint64) (s string, err error) { + err = defaultDatasize.UnmarshalText([]byte(fmt.Sprintf("%d", bytes))) + if err != nil { + return + } + s = defaultDatasize.HumanReadable() + return +} +func (b ByteSize) Bytes() uint64 { + return uint64(b) +} + +func (b ByteSize) KBytes() float64 { + v := b / KB + r := b % KB + return float64(v) + float64(r)/float64(KB) +} + +func (b ByteSize) MBytes() float64 { + v := b / MB + r := b % MB + return float64(v) + float64(r)/float64(MB) +} + +func (b ByteSize) GBytes() float64 { + v := b / GB + r := b % GB + return float64(v) + float64(r)/float64(GB) +} + +func (b ByteSize) TBytes() float64 { + v := b / TB + r := b % TB + return float64(v) + float64(r)/float64(TB) +} + +func (b ByteSize) PBytes() float64 { + v := b / PB + r := b % PB + return float64(v) + float64(r)/float64(PB) +} + +func (b ByteSize) EBytes() float64 { + v := b / EB + r := b % EB + return float64(v) + float64(r)/float64(EB) +} + +func (b ByteSize) String() string { + switch { + case b == 0: + return fmt.Sprint("0B") + case b%EB == 0: + return fmt.Sprintf("%dEB", b/EB) + case b%PB == 0: + return fmt.Sprintf("%dPB", b/PB) + case b%TB == 0: + return fmt.Sprintf("%dTB", b/TB) + case b%GB == 0: + return fmt.Sprintf("%dGB", b/GB) + case b%MB == 0: + return fmt.Sprintf("%dMB", b/MB) + case b%KB == 0: + return fmt.Sprintf("%dKB", b/KB) + default: + return fmt.Sprintf("%dB", b) + } +} + +func (b ByteSize) HR() string { + return b.HumanReadable() +} + +func (b ByteSize) HumanReadable() string { + switch { + case b > EB: + return fmt.Sprintf("%.1f EB", b.EBytes()) + case b > PB: + return fmt.Sprintf("%.1f PB", b.PBytes()) + case b > TB: + return fmt.Sprintf("%.1f TB", b.TBytes()) + case b > GB: + return fmt.Sprintf("%.1f GB", b.GBytes()) + case b > MB: + return fmt.Sprintf("%.1f MB", b.MBytes()) + case b > KB: + return fmt.Sprintf("%.1f KB", b.KBytes()) + default: + return fmt.Sprintf("%d B", b) + } +} + +func (b ByteSize) MarshalText() ([]byte, error) { + return []byte(b.String()), nil +} + +func (b *ByteSize) UnmarshalText(t []byte) error { + var val uint64 + var unit string + + // copy for error message + t0 := t + + var c byte + var i int + +ParseLoop: + for i < len(t) { + c = t[i] + switch { + case '0' <= c && c <= '9': + if val > cutoff { + goto Overflow + } + + c = c - '0' + val *= 10 + + if val > val+uint64(c) { + // val+v overflows + goto Overflow + } + val += uint64(c) + i++ + + default: + if i == 0 { + goto SyntaxError + } + break ParseLoop + } + } + + unit = strings.TrimSpace(string(t[i:])) + switch unit { + case "Kb", "Mb", "Gb", "Tb", "Pb", "Eb": + goto BitsError + } + unit = strings.ToLower(unit) + switch unit { + case "", "b", "byte": + // do nothing - already in bytes + + case "k", "kb", "kilo", "kilobyte", "kilobytes": + if val > maxUint64/uint64(KB) { + goto Overflow + } + val *= uint64(KB) + + case "m", "mb", "mega", "megabyte", "megabytes": + if val > maxUint64/uint64(MB) { + goto Overflow + } + val *= uint64(MB) + + case "g", "gb", "giga", "gigabyte", "gigabytes": + if val > maxUint64/uint64(GB) { + goto Overflow + } + val *= uint64(GB) + + case "t", "tb", "tera", "terabyte", "terabytes": + if val > maxUint64/uint64(TB) { + goto Overflow + } + val *= uint64(TB) + + case "p", "pb", "peta", "petabyte", "petabytes": + if val > maxUint64/uint64(PB) { + goto Overflow + } + val *= uint64(PB) + + case "E", "EB", "e", "eb", "eB": + if val > maxUint64/uint64(EB) { + goto Overflow + } + val *= uint64(EB) + + default: + goto SyntaxError + } + + *b = ByteSize(val) + return nil + +Overflow: + *b = ByteSize(maxUint64) + return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrRange} + +SyntaxError: + *b = 0 + return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrSyntax} + +BitsError: + *b = 0 + return &strconv.NumError{fnUnmarshalText, string(t0), ErrBits} +} diff --git a/utils/dnsx/resolver.go b/utils/dnsx/resolver.go new file mode 100644 index 0000000..f1a6514 --- /dev/null +++ b/utils/dnsx/resolver.go @@ -0,0 +1,114 @@ +package dnsx + +import ( + "fmt" + logger "log" + "net" + "strings" + "time" + + "bitbucket.org/snail/proxy/utils/mapx" + dns "github.com/miekg/dns" +) + +type DomainResolver struct { + ttl int + dnsAddrress string + data mapx.ConcurrentMap + log *logger.Logger +} +type DomainResolverItem struct { + ip string + domain string + expiredAt int64 +} + +func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver { + return DomainResolver{ + ttl: ttl, + dnsAddrress: dnsAddrress, + data: mapx.NewConcurrentMap(), + log: log, + } +} +func (a *DomainResolver) DnsAddress() (address string) { + address = a.dnsAddrress + return +} +func (a *DomainResolver) MustResolve(address string) (ip string) { + ip, _ = a.Resolve(address) + return +} +func (a *DomainResolver) Resolve(address string) (ip string, err error) { + domain := address + port := "" + fromCache := "false" + defer func() { + if port != "" { + ip = net.JoinHostPort(ip, port) + } + a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache) + //a.PrintData() + }() + if strings.Contains(domain, ":") { + domain, port, err = net.SplitHostPort(domain) + if err != nil { + return + } + } + if net.ParseIP(domain) != nil { + ip = domain + fromCache = "ip ignore" + return + } + item, ok := a.data.Get(domain) + if ok { + //log.Println("find ", domain) + if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() { + ip = (*item.(*DomainResolverItem)).ip + fromCache = "true" + //log.Println("from cache ", domain) + return + } + } else { + item = &DomainResolverItem{ + domain: domain, + } + + } + c := new(dns.Client) + c.DialTimeout = time.Millisecond * 5000 + c.ReadTimeout = time.Millisecond * 5000 + c.WriteTimeout = time.Millisecond * 5000 + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + r, _, err := c.Exchange(m, a.dnsAddrress) + if r == nil { + return + } + if r.Rcode != dns.RcodeSuccess { + err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress) + return + } + for _, answer := range r.Answer { + if answer.Header().Rrtype == dns.TypeA { + info := strings.Fields(answer.String()) + if len(info) >= 5 { + ip = info[4] + _item := item.(*DomainResolverItem) + (*_item).expiredAt = time.Now().Unix() + int64(a.ttl) + (*_item).ip = ip + a.data.Set(domain, item) + return + } + } + } + return +} +func (a *DomainResolver) PrintData() { + for k, item := range a.data.Items() { + d := item.(*DomainResolverItem) + a.log.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt) + } +} diff --git a/utils/iolimiter/iolimiter.go b/utils/iolimiter/iolimiter.go new file mode 100644 index 0000000..8bc478f --- /dev/null +++ b/utils/iolimiter/iolimiter.go @@ -0,0 +1,178 @@ +package iolimiter + +import ( + "context" + "io" + "net" + "time" + + "golang.org/x/time/rate" +) + +const burstLimit = 1000 * 1000 * 1000 + +type Reader struct { + r io.Reader + limiter *rate.Limiter + ctx context.Context +} + +type Writer struct { + w io.Writer + limiter *rate.Limiter + ctx context.Context +} + +type conn struct { + net.Conn + r io.Reader + w io.Writer + readLimiter *rate.Limiter + writeLimiter *rate.Limiter + ctx context.Context +} + +//NewtRateLimitConn sets rate limit (bytes/sec) to the Conn read and write. +func NewtConn(c net.Conn, bytesPerSec float64) net.Conn { + s := &conn{ + Conn: c, + r: c, + w: c, + ctx: context.Background(), + } + s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit) + s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst + s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit) + s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst + return s +} + +//NewtRateLimitReaderConn sets rate limit (bytes/sec) to the Conn read. +func NewReaderConn(c net.Conn, bytesPerSec float64) net.Conn { + s := &conn{ + Conn: c, + r: c, + w: c, + ctx: context.Background(), + } + s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit) + s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst + return s +} + +//NewtRateLimitWriterConn sets rate limit (bytes/sec) to the Conn write. +func NewWriterConn(c net.Conn, bytesPerSec float64) net.Conn { + s := &conn{ + Conn: c, + r: c, + w: c, + ctx: context.Background(), + } + s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit) + s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst + return s +} + +// Read reads bytes into p. +func (s *conn) Read(p []byte) (int, error) { + if s.readLimiter == nil { + return s.r.Read(p) + } + n, err := s.r.Read(p) + if err != nil { + return n, err + } + if err := s.readLimiter.WaitN(s.ctx, n); err != nil { + return n, err + } + return n, nil +} + +// Write writes bytes from p. +func (s *conn) Write(p []byte) (int, error) { + if s.writeLimiter == nil { + return s.w.Write(p) + } + n, err := s.w.Write(p) + if err != nil { + return n, err + } + if err := s.writeLimiter.WaitN(s.ctx, n); err != nil { + return n, err + } + return n, err +} + +// NewReader returns a reader that implements io.Reader with rate limiting. +func NewReader(r io.Reader) *Reader { + return &Reader{ + r: r, + ctx: context.Background(), + } +} + +// NewReaderWithContext returns a reader that implements io.Reader with rate limiting. +func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader { + return &Reader{ + r: r, + ctx: ctx, + } +} + +// NewWriter returns a writer that implements io.Writer with rate limiting. +func NewWriter(w io.Writer) *Writer { + return &Writer{ + w: w, + ctx: context.Background(), + } +} + +// NewWriterWithContext returns a writer that implements io.Writer with rate limiting. +func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer { + return &Writer{ + w: w, + ctx: ctx, + } +} + +// SetRateLimit sets rate limit (bytes/sec) to the reader. +func (s *Reader) SetRateLimit(bytesPerSec float64) { + s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit) + s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst +} + +// Read reads bytes into p. +func (s *Reader) Read(p []byte) (int, error) { + if s.limiter == nil { + return s.r.Read(p) + } + n, err := s.r.Read(p) + if err != nil { + return n, err + } + if err := s.limiter.WaitN(s.ctx, n); err != nil { + return n, err + } + return n, nil +} + +// SetRateLimit sets rate limit (bytes/sec) to the writer. +func (s *Writer) SetRateLimit(bytesPerSec float64) { + s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit) + s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst +} + +// Write writes bytes from p. +func (s *Writer) Write(p []byte) (int, error) { + if s.limiter == nil { + return s.w.Write(p) + } + n, err := s.w.Write(p) + if err != nil { + return n, err + } + if err := s.limiter.WaitN(s.ctx, n); err != nil { + return n, err + } + return n, err +} diff --git a/utils/lb/backend.go b/utils/lb/backend.go new file mode 100644 index 0000000..90af686 --- /dev/null +++ b/utils/lb/backend.go @@ -0,0 +1,203 @@ +package lb + +import ( + "errors" + "log" + "net" + "sync" + "time" + + "bitbucket.org/snail/proxy/utils/dnsx" +) + +// BackendConfig it's the configuration loaded +type BackendConfig struct { + Address string + + ActiveAfter int + InactiveAfter int + Weight int + + Timeout time.Duration + RetryTime time.Duration + + IsMuxCheck bool + ConnFactory func(address string, timeout time.Duration) (net.Conn, error) +} +type BackendsConfig []*BackendConfig + +// BackendControl keep the control data +type BackendControl struct { + Failed bool // The last request failed + Active bool + + InactiveTries int + ActiveTries int + + Connections int + + ConnectUsedMillisecond int + + isStop bool +} + +// Backend structure +type Backend struct { + BackendConfig + BackendControl + sync.RWMutex + log *log.Logger + dr *dnsx.DomainResolver +} + +type Backends []*Backend + +func NewBackend(backendConfig BackendConfig, dr *dnsx.DomainResolver, log *log.Logger) (*Backend, error) { + + if backendConfig.Address == "" { + return nil, errors.New("Address rquired") + } + if backendConfig.ActiveAfter == 0 { + backendConfig.ActiveAfter = 2 + } + if backendConfig.InactiveAfter == 0 { + backendConfig.InactiveAfter = 3 + } + if backendConfig.Weight == 0 { + backendConfig.Weight = 1 + } + if backendConfig.Timeout == 0 { + backendConfig.Timeout = time.Millisecond * 1500 + } + if backendConfig.RetryTime == 0 { + backendConfig.RetryTime = time.Millisecond * 2000 + } + return &Backend{ + dr: dr, + log: log, + BackendConfig: backendConfig, + BackendControl: BackendControl{ + Failed: true, + Active: false, + InactiveTries: 0, + ActiveTries: 0, + Connections: 0, + ConnectUsedMillisecond: 0, + isStop: false, + }, + }, nil +} +func (b *Backend) StopHeartCheck() { + b.isStop = true +} + +func (b *Backend) IncreasConns() { + b.RWMutex.Lock() + b.Connections++ + b.RWMutex.Unlock() +} + +func (b *Backend) DecreaseConns() { + b.RWMutex.Lock() + b.Connections-- + b.RWMutex.Unlock() +} + +func (b *Backend) StartHeartCheck() { + if b.IsMuxCheck { + b.startMuxHeartCheck() + } else { + b.startTCPHeartCheck() + } +} +func (b *Backend) startMuxHeartCheck() { + go func() { + for { + if b.isStop { + return + } + var c net.Conn + var err error + start := time.Now().UnixNano() / int64(time.Microsecond) + c, err = b.getConn() + b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start) + if err != nil { + b.Active = false + time.Sleep(time.Second * 2) + continue + } else { + b.Active = true + } + for { + buf := make([]byte, 1) + c.Read(buf) + buf = nil + break + } + b.Active = false + } + }() +} + +// Monitoring the backend +func (b *Backend) startTCPHeartCheck() { + go func() { + for { + if b.isStop { + return + } + var c net.Conn + var err error + start := time.Now().UnixNano() / int64(time.Microsecond) + c, err = b.getConn() + b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start) + if err == nil { + c.Close() + } + if err != nil { + b.RWMutex.Lock() + // Max tries before consider inactive + if b.InactiveTries >= b.InactiveAfter { + //b.log.Printf("Backend inactive [%s]", b.Address) + b.Active = false + b.ActiveTries = 0 + } else { + // Ok that guy it's out of the game + b.Failed = true + b.InactiveTries++ + //b.log.Printf("Error to check address [%s] tries [%d]", b.Address, b.InactiveTries) + } + b.RWMutex.Unlock() + } else { + + // Ok, let's keep working boys + b.RWMutex.Lock() + if b.ActiveTries >= b.ActiveAfter { + if b.Failed { + //log.Printf("Backend active [%s]", b.Address) + } + b.Failed = false + b.Active = true + b.InactiveTries = 0 + } else { + b.ActiveTries++ + } + b.RWMutex.Unlock() + } + time.Sleep(b.RetryTime) + } + }() +} +func (b *Backend) getConn() (conn net.Conn, err error) { + address := b.Address + if b.dr != nil && b.dr.DnsAddress() != "" { + address, err = b.dr.Resolve(b.Address) + if err != nil { + b.log.Printf("dns error %s , ERR:%s", b.Address, err) + } + } + if b.ConnFactory != nil { + return b.ConnFactory(address, b.Timeout) + } + return net.DialTimeout("tcp", address, b.Timeout) +} diff --git a/utils/lb/lb.go b/utils/lb/lb.go new file mode 100644 index 0000000..f1ae994 --- /dev/null +++ b/utils/lb/lb.go @@ -0,0 +1,687 @@ +package lb + +import ( + "crypto/md5" + "log" + "net" + "sync" + + "bitbucket.org/snail/proxy/utils/dnsx" +) + +const ( + SELECT_ROUNDROBIN = iota + SELECT_LEASTCONN + SELECT_HASH + SELECT_WEITHT + SELECT_LEASTTIME +) + +type Selector interface { + Select(srcAddr string) (addr string) + SelectBackend(srcAddr string) (b *Backend) + IncreasConns(addr string) + DecreaseConns(addr string) + Stop() + Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) + IsActive() bool + ActiveCount() (count int) + Backends() (bs []*Backend) +} + +type Group struct { + selector *Selector + log *log.Logger + dr *dnsx.DomainResolver + lock *sync.Mutex + last *Backend + debug bool +} + +func NewGroup(selectType int, configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger, debug bool) Group { + bks := []*Backend{} + for _, c := range configs { + b, _ := NewBackend(*c, dr, log) + bks = append(bks, b) + } + if len(bks) > 1 { + for _, b := range bks { + b.StartHeartCheck() + } + } + var s Selector + switch selectType { + case SELECT_ROUNDROBIN: + s = NewRoundRobin(bks, log, debug) + case SELECT_LEASTCONN: + s = NewLeastConn(bks, log, debug) + case SELECT_HASH: + s = NewHash(bks, log, debug) + case SELECT_WEITHT: + s = NewWeight(bks, log, debug) + case SELECT_LEASTTIME: + s = NewLeastTime(bks, log, debug) + } + return Group{ + selector: &s, + log: log, + dr: dr, + lock: &sync.Mutex{}, + debug: debug, + } +} +func (g *Group) Select(srcAddr string, onlyHa bool) (addr string) { + if onlyHa { + g.lock.Lock() + defer g.lock.Unlock() + if g.last != nil && (g.last.Active || g.last.ConnectUsedMillisecond == 0) { + if g.debug { + g.log.Printf("############ choosed %s from lastest ############", g.last.Address) + printDebug(true, g.log, nil, srcAddr, (*g.selector).Backends()) + } + return g.last.Address + } + g.last = (*g.selector).SelectBackend(srcAddr) + if !g.last.Active && g.last.ConnectUsedMillisecond > 0 { + g.log.Printf("###warn### lb selected empty , return default , for : %s", srcAddr) + } + return g.last.Address + } + b := (*g.selector).SelectBackend(srcAddr) + return b.Address + +} +func (g *Group) IncreasConns(addr string) { + (*g.selector).IncreasConns(addr) +} +func (g *Group) DecreaseConns(addr string) { + (*g.selector).DecreaseConns(addr) +} +func (g *Group) Stop() { + if g.selector != nil { + (*g.selector).Stop() + } +} +func (g *Group) IsActive() bool { + return (*g.selector).IsActive() +} +func (g *Group) ActiveCount() (count int) { + return (*g.selector).ActiveCount() +} +func (g *Group) Reset(addrs []string) { + bks := (*g.selector).Backends() + if len(bks) == 0 { + return + } + cfg := bks[0].BackendConfig + configs := BackendsConfig{} + for _, addr := range addrs { + c := cfg + c.Address = addr + configs = append(configs, &c) + } + (*g.selector).Reset(configs, g.dr, g.log) +} +func (g *Group) Backends() []*Backend { + return (*g.selector).Backends() +} + +//########################RoundRobin########################## +type RoundRobin struct { + sync.Mutex + backendIndex int + backends Backends + log *log.Logger + debug bool +} + +func NewRoundRobin(backends Backends, log *log.Logger, debug bool) Selector { + return &RoundRobin{ + backends: backends, + log: log, + debug: debug, + } + +} +func (r *RoundRobin) Select(srcAddr string) (addr string) { + return r.SelectBackend(srcAddr).Address +} +func (r *RoundRobin) SelectBackend(srcAddr string) (b *Backend) { + r.Lock() + defer r.Unlock() + defer func() { + printDebug(r.debug, r.log, b, srcAddr, r.backends) + }() + if len(r.backends) == 0 { + return + } + if len(r.backends) == 1 { + return r.backends[0] + } +RETRY: + found := false + for _, b := range r.backends { + if b.Active { + found = true + break + } + } + if !found { + return r.backends[0] + } + r.backendIndex++ + if r.backendIndex > len(r.backends)-1 { + r.backendIndex = 0 + } + if !r.backends[r.backendIndex].Active { + goto RETRY + } + return r.backends[r.backendIndex] +} +func (r *RoundRobin) IncreasConns(addr string) { + +} +func (r *RoundRobin) DecreaseConns(addr string) { + +} +func (r *RoundRobin) Stop() { + for _, b := range r.backends { + b.StopHeartCheck() + } +} +func (r *RoundRobin) Backends() []*Backend { + return r.backends +} +func (r *RoundRobin) IsActive() bool { + for _, b := range r.backends { + if b.Active { + return true + } + } + return false +} +func (r *RoundRobin) ActiveCount() (count int) { + for _, b := range r.backends { + if b.Active { + count++ + } + } + return +} +func (r *RoundRobin) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) { + r.Lock() + defer r.Unlock() + r.Stop() + bks := []*Backend{} + for _, c := range configs { + b, _ := NewBackend(*c, dr, log) + bks = append(bks, b) + } + if len(bks) > 1 { + for _, b := range bks { + b.StartHeartCheck() + } + } + r.backends = bks +} + +//########################LeastConn########################## + +type LeastConn struct { + sync.Mutex + backends Backends + log *log.Logger + debug bool +} + +func NewLeastConn(backends []*Backend, log *log.Logger, debug bool) Selector { + lc := LeastConn{ + backends: backends, + log: log, + debug: debug, + } + return &lc +} + +func (lc *LeastConn) Select(srcAddr string) (addr string) { + return lc.SelectBackend(srcAddr).Address +} +func (lc *LeastConn) SelectBackend(srcAddr string) (b *Backend) { + lc.Lock() + defer lc.Unlock() + defer func() { + printDebug(lc.debug, lc.log, b, srcAddr, lc.backends) + }() + if len(lc.backends) == 0 { + return + } + if len(lc.backends) == 1 { + return lc.backends[0] + } + found := false + for _, b := range lc.backends { + if b.Active { + found = true + break + } + } + if !found { + return lc.backends[0] + } + min := lc.backends[0].Connections + index := 0 + for i, b := range lc.backends { + if b.Active { + min = b.Connections + index = i + break + } + } + for i, b := range lc.backends { + if b.Active && b.Connections <= min { + min = b.Connections + index = i + } + } + return lc.backends[index] +} +func (lc *LeastConn) IncreasConns(addr string) { + for _, a := range lc.backends { + if a.Address == addr { + a.IncreasConns() + return + } + } +} +func (lc *LeastConn) DecreaseConns(addr string) { + for _, a := range lc.backends { + if a.Address == addr { + a.DecreaseConns() + return + } + } +} +func (lc *LeastConn) Stop() { + for _, b := range lc.backends { + b.StopHeartCheck() + } +} +func (lc *LeastConn) IsActive() bool { + for _, b := range lc.backends { + if b.Active { + return true + } + } + return false +} +func (lc *LeastConn) ActiveCount() (count int) { + for _, b := range lc.backends { + if b.Active { + count++ + } + } + return +} +func (lc *LeastConn) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) { + lc.Lock() + defer lc.Unlock() + lc.Stop() + bks := []*Backend{} + for _, c := range configs { + b, _ := NewBackend(*c, dr, log) + bks = append(bks, b) + } + if len(bks) > 1 { + for _, b := range bks { + b.StartHeartCheck() + } + } + lc.backends = bks +} +func (lc *LeastConn) Backends() []*Backend { + return lc.backends +} + +//########################Hash########################## +type Hash struct { + sync.Mutex + backends Backends + log *log.Logger + debug bool +} + +func NewHash(backends Backends, log *log.Logger, debug bool) Selector { + return &Hash{ + backends: backends, + log: log, + debug: debug, + } +} +func (h *Hash) Select(srcAddr string) (addr string) { + return h.SelectBackend(srcAddr).Address +} +func (h *Hash) SelectBackend(srcAddr string) (b *Backend) { + h.Lock() + defer h.Unlock() + defer func() { + printDebug(h.debug, h.log, b, srcAddr, h.backends) + }() + if len(h.backends) == 0 { + return + } + if len(h.backends) == 1 { + return h.backends[0] + } + i := 0 + host, _, err := net.SplitHostPort(srcAddr) + if err != nil { + return + } + //porti, _ := strconv.Atoi(port) + //i += porti + for _, b := range md5.Sum([]byte(host)) { + i += int(b) + } +RETRY: + found := false + for _, b := range h.backends { + if b.Active { + found = true + break + } + } + if !found { + return h.backends[0] + } + k := i % len(h.backends) + if !h.backends[k].Active { + i++ + goto RETRY + } + return h.backends[k] +} +func (h *Hash) IncreasConns(addr string) { + +} +func (h *Hash) DecreaseConns(addr string) { + +} +func (h *Hash) Stop() { + for _, b := range h.backends { + b.StopHeartCheck() + } +} +func (h *Hash) IsActive() bool { + for _, b := range h.backends { + if b.Active { + return true + } + } + return false +} +func (h *Hash) ActiveCount() (count int) { + for _, b := range h.backends { + if b.Active { + count++ + } + } + return +} +func (h *Hash) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) { + h.Lock() + defer h.Unlock() + h.Stop() + bks := []*Backend{} + for _, c := range configs { + b, _ := NewBackend(*c, dr, log) + bks = append(bks, b) + } + if len(bks) > 1 { + for _, b := range bks { + b.StartHeartCheck() + } + } + h.backends = bks +} +func (h *Hash) Backends() []*Backend { + return h.backends +} + +//########################Weight########################## +type Weight struct { + sync.Mutex + backends Backends + log *log.Logger + debug bool +} + +func NewWeight(backends Backends, log *log.Logger, debug bool) Selector { + return &Weight{ + backends: backends, + log: log, + debug: debug, + } +} +func (w *Weight) Select(srcAddr string) (addr string) { + return w.SelectBackend(srcAddr).Address +} +func (w *Weight) SelectBackend(srcAddr string) (b *Backend) { + w.Lock() + defer w.Unlock() + defer func() { + printDebug(w.debug, w.log, b, srcAddr, w.backends) + }() + if len(w.backends) == 0 { + return + } + if len(w.backends) == 1 { + return w.backends[0] + } + + found := false + for _, b := range w.backends { + if b.Active { + found = true + break + } + } + if !found { + return w.backends[0] + } + + min := w.backends[0].Connections / w.backends[0].Weight + index := 0 + for i, b := range w.backends { + if b.Active { + min = b.Connections / b.Weight + index = i + break + } + } + for i, b := range w.backends { + if b.Active && b.Connections/b.Weight <= min { + min = b.Connections + index = i + } + } + return w.backends[index] +} +func (w *Weight) IncreasConns(addr string) { + w.Lock() + defer w.Unlock() + for _, a := range w.backends { + if a.Address == addr { + a.IncreasConns() + return + } + } +} +func (w *Weight) DecreaseConns(addr string) { + w.Lock() + defer w.Unlock() + for _, a := range w.backends { + if a.Address == addr { + a.DecreaseConns() + return + } + } +} +func (w *Weight) Stop() { + for _, b := range w.backends { + b.StopHeartCheck() + } +} +func (w *Weight) IsActive() bool { + for _, b := range w.backends { + if b.Active { + return true + } + } + return false +} +func (w *Weight) ActiveCount() (count int) { + for _, b := range w.backends { + if b.Active { + count++ + } + } + return +} +func (w *Weight) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) { + w.Lock() + defer w.Unlock() + w.Stop() + bks := []*Backend{} + for _, c := range configs { + b, _ := NewBackend(*c, dr, log) + bks = append(bks, b) + } + if len(bks) > 1 { + for _, b := range bks { + b.StartHeartCheck() + } + } + w.backends = bks +} +func (w *Weight) Backends() []*Backend { + return w.backends +} + +//########################LeastTime########################## + +type LeastTime struct { + sync.Mutex + backends Backends + log *log.Logger + debug bool +} + +func NewLeastTime(backends []*Backend, log *log.Logger, debug bool) Selector { + lt := LeastTime{ + backends: backends, + log: log, + debug: debug, + } + return < +} + +func (lt *LeastTime) Select(srcAddr string) (addr string) { + return lt.SelectBackend(srcAddr).Address +} +func (lt *LeastTime) SelectBackend(srcAddr string) (b *Backend) { + lt.Lock() + defer lt.Unlock() + defer func() { + printDebug(lt.debug, lt.log, b, srcAddr, lt.backends) + }() + if len(lt.backends) == 0 { + return + } + if len(lt.backends) == 1 { + return lt.backends[0] + } + found := false + for _, b := range lt.backends { + if b.Active { + found = true + break + } + } + if !found { + return lt.backends[0] + } + min := lt.backends[0].ConnectUsedMillisecond + index := 0 + for i, b := range lt.backends { + if b.Active { + min = b.ConnectUsedMillisecond + index = i + break + } + } + for i, b := range lt.backends { + if b.Active && b.ConnectUsedMillisecond > 0 && b.ConnectUsedMillisecond <= min { + min = b.ConnectUsedMillisecond + index = i + } + } + return lt.backends[index] +} +func (lt *LeastTime) IncreasConns(addr string) { + +} +func (lt *LeastTime) DecreaseConns(addr string) { + +} +func (lt *LeastTime) Stop() { + for _, b := range lt.backends { + b.StopHeartCheck() + } +} +func (lt *LeastTime) IsActive() bool { + for _, b := range lt.backends { + if b.Active { + return true + } + } + return false +} +func (lt *LeastTime) ActiveCount() (count int) { + for _, b := range lt.backends { + if b.Active { + count++ + } + } + return +} +func (lt *LeastTime) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) { + lt.Lock() + defer lt.Unlock() + lt.Stop() + bks := []*Backend{} + for _, c := range configs { + b, _ := NewBackend(*c, dr, log) + bks = append(bks, b) + } + if len(bks) > 1 { + for _, b := range bks { + b.StartHeartCheck() + } + } + lt.backends = bks +} +func (lt *LeastTime) Backends() []*Backend { + return lt.backends +} +func printDebug(isDebug bool, log *log.Logger, selected *Backend, srcAddr string, backends []*Backend) { + if isDebug { + log.Printf("############ LB start ############\n") + if selected != nil { + log.Printf("choosed %s for %s\n", selected.Address, srcAddr) + } + for _, v := range backends { + log.Printf("addr:%s,conns:%d,time:%d,weight:%d,active:%v\n", v.Address, v.Connections, v.ConnectUsedMillisecond, v.Weight, v.Active) + } + log.Printf("############ LB end ############\n") + } +} diff --git a/utils/map.go b/utils/mapx/map.go similarity index 99% rename from utils/map.go rename to utils/mapx/map.go index 8ec82cf..3c10df9 100644 --- a/utils/map.go +++ b/utils/mapx/map.go @@ -1,4 +1,4 @@ -package utils +package mapx import ( "encoding/json" diff --git a/utils/ss/conn.go b/utils/ss/conn.go new file mode 100644 index 0000000..83f0681 --- /dev/null +++ b/utils/ss/conn.go @@ -0,0 +1,193 @@ +package ss + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strconv" +) + +const ( + OneTimeAuthMask byte = 0x10 + AddrMask byte = 0xf +) + +type Conn struct { + net.Conn + *Cipher + readBuf []byte + writeBuf []byte + chunkId uint32 +} + +func NewConn(c net.Conn, cipher *Cipher) *Conn { + return &Conn{ + Conn: c, + Cipher: cipher, + readBuf: leakyBuf.Get(), + writeBuf: leakyBuf.Get()} +} + +func (c *Conn) Close() error { + leakyBuf.Put(c.readBuf) + leakyBuf.Put(c.writeBuf) + return c.Conn.Close() +} + +func RawAddr(addr string) (buf []byte, err error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("ss: address error %s %v", addr, err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("ss: invalid port %s", addr) + } + + hostLen := len(host) + l := 1 + 1 + hostLen + 2 // addrType + lenByte + address + port + buf = make([]byte, l) + buf[0] = 3 // 3 means the address is domain name + buf[1] = byte(hostLen) // host address length followed by host address + copy(buf[2:], host) + binary.BigEndian.PutUint16(buf[2+hostLen:2+hostLen+2], uint16(port)) + return +} + +// This is intended for use by users implementing a local socks proxy. +// rawaddr shoud contain part of the data in socks request, starting from the +// ATYP field. (Refer to rfc1928 for more information.) +func DialWithRawAddr(rawConn *net.Conn, rawaddr []byte, server string, cipher *Cipher) (c *Conn, err error) { + var conn net.Conn + if rawConn == nil { + conn, err = net.Dial("tcp", server) + } + if err != nil { + return + } + if rawConn != nil { + c = NewConn(*rawConn, cipher) + } else { + c = NewConn(conn, cipher) + } + if cipher.ota { + if c.enc == nil { + if _, err = c.initEncrypt(); err != nil { + return + } + } + // since we have initEncrypt, we must send iv manually + conn.Write(cipher.iv) + rawaddr[0] |= OneTimeAuthMask + rawaddr = otaConnectAuth(cipher.iv, cipher.key, rawaddr) + } + if _, err = c.write(rawaddr); err != nil { + c.Close() + return nil, err + } + return +} + +// addr should be in the form of host:port +func Dial(addr, server string, cipher *Cipher) (c *Conn, err error) { + ra, err := RawAddr(addr) + if err != nil { + return + } + return DialWithRawAddr(nil, ra, server, cipher) +} + +func (c *Conn) GetIv() (iv []byte) { + iv = make([]byte, len(c.iv)) + copy(iv, c.iv) + return +} + +func (c *Conn) GetKey() (key []byte) { + key = make([]byte, len(c.key)) + copy(key, c.key) + return +} + +func (c *Conn) IsOta() bool { + return c.ota +} + +func (c *Conn) GetAndIncrChunkId() (chunkId uint32) { + chunkId = c.chunkId + c.chunkId += 1 + return +} + +func (c *Conn) Read(b []byte) (n int, err error) { + if c.dec == nil { + iv := make([]byte, c.info.ivLen) + if _, err = io.ReadFull(c.Conn, iv); err != nil { + return + } + if err = c.initDecrypt(iv); err != nil { + return + } + if len(c.iv) == 0 { + c.iv = iv + } + } + + cipherData := c.readBuf + if len(b) > len(cipherData) { + cipherData = make([]byte, len(b)) + } else { + cipherData = cipherData[:len(b)] + } + + n, err = c.Conn.Read(cipherData) + if n > 0 { + c.decrypt(b[0:n], cipherData[0:n]) + } + return +} + +func (c *Conn) Write(b []byte) (n int, err error) { + nn := len(b) + if c.ota { + chunkId := c.GetAndIncrChunkId() + b = otaReqChunkAuth(c.iv, chunkId, b) + } + headerLen := len(b) - nn + + n, err = c.write(b) + // Make sure <= 0 <= len(b), where b is the slice passed in. + if n >= headerLen { + n -= headerLen + } + return +} + +func (c *Conn) write(b []byte) (n int, err error) { + var iv []byte + if c.enc == nil { + iv, err = c.initEncrypt() + if err != nil { + return + } + } + + cipherData := c.writeBuf + dataSize := len(b) + len(iv) + if dataSize > len(cipherData) { + cipherData = make([]byte, dataSize) + } else { + cipherData = cipherData[:dataSize] + } + + if iv != nil { + // Put initialization vector in buffer, do a single write to send both + // iv and data. + copy(cipherData, iv) + } + + c.encrypt(cipherData[len(iv):], b) + n, err = c.Conn.Write(cipherData) + return +} diff --git a/utils/ss/encrypt.go b/utils/ss/encrypt.go new file mode 100644 index 0000000..3a40d64 --- /dev/null +++ b/utils/ss/encrypt.go @@ -0,0 +1,301 @@ +package ss + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/md5" + "crypto/rand" + "crypto/rc4" + "encoding/binary" + "errors" + "io" + "strings" + + "github.com/Yawning/chacha20" + "golang.org/x/crypto/blowfish" + "golang.org/x/crypto/cast5" + "golang.org/x/crypto/salsa20/salsa" +) + +var errEmptyPassword = errors.New("empty key") + +func md5sum(d []byte) []byte { + h := md5.New() + h.Write(d) + return h.Sum(nil) +} + +func evpBytesToKey(password string, keyLen int) (key []byte) { + const md5Len = 16 + + cnt := (keyLen-1)/md5Len + 1 + m := make([]byte, cnt*md5Len) + copy(m, md5sum([]byte(password))) + + // Repeatedly call md5 until bytes generated is enough. + // Each call to md5 uses data: prev md5 sum + password. + d := make([]byte, md5Len+len(password)) + start := 0 + for i := 1; i < cnt; i++ { + start += md5Len + copy(d, m[start-md5Len:start]) + copy(d[md5Len:], password) + copy(m[start:], md5sum(d)) + } + return m[:keyLen] +} + +type DecOrEnc int + +const ( + Decrypt DecOrEnc = iota + Encrypt +) + +func newStream(block cipher.Block, err error, key, iv []byte, + doe DecOrEnc) (cipher.Stream, error) { + if err != nil { + return nil, err + } + if doe == Encrypt { + return cipher.NewCFBEncrypter(block, iv), nil + } else { + return cipher.NewCFBDecrypter(block, iv), nil + } +} + +func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := aes.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(block, iv), nil +} + +func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := des.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newBlowFishStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := blowfish.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newCast5Stream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) { + block, err := cast5.NewCipher(key) + return newStream(block, err, key, iv, doe) +} + +func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + h := md5.New() + h.Write(key) + h.Write(iv) + rc4key := h.Sum(nil) + + return rc4.NewCipher(rc4key) +} + +func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + return chacha20.NewCipher(key, iv) +} + +func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + return chacha20.NewCipher(key, iv) +} + +type salsaStreamCipher struct { + nonce [8]byte + key [32]byte + counter int +} + +func (c *salsaStreamCipher) XORKeyStream(dst, src []byte) { + var buf []byte + padLen := c.counter % 64 + dataSize := len(src) + padLen + if cap(dst) >= dataSize { + buf = dst[:dataSize] + } else if leakyBufSize >= dataSize { + buf = leakyBuf.Get() + defer leakyBuf.Put(buf) + buf = buf[:dataSize] + } else { + buf = make([]byte, dataSize) + } + + var subNonce [16]byte + copy(subNonce[:], c.nonce[:]) + binary.LittleEndian.PutUint64(subNonce[len(c.nonce):], uint64(c.counter/64)) + + // It's difficult to avoid data copy here. src or dst maybe slice from + // Conn.Read/Write, which can't have padding. + copy(buf[padLen:], src[:]) + salsa.XORKeyStream(buf, buf, &subNonce, &c.key) + copy(dst, buf[padLen:]) + + c.counter += len(src) +} + +func newSalsa20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) { + var c salsaStreamCipher + copy(c.nonce[:], iv[:8]) + copy(c.key[:], key[:32]) + return &c, nil +} + +type cipherInfo struct { + keyLen int + ivLen int + newStream func(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) +} + +var cipherMethod = map[string]*cipherInfo{ + "aes-128-cfb": {16, 16, newAESCFBStream}, + "aes-192-cfb": {24, 16, newAESCFBStream}, + "aes-256-cfb": {32, 16, newAESCFBStream}, + "aes-128-ctr": {16, 16, newAESCTRStream}, + "aes-192-ctr": {24, 16, newAESCTRStream}, + "aes-256-ctr": {32, 16, newAESCTRStream}, + "des-cfb": {8, 8, newDESStream}, + "bf-cfb": {16, 8, newBlowFishStream}, + "cast5-cfb": {16, 8, newCast5Stream}, + "rc4-md5": {16, 16, newRC4MD5Stream}, + "rc4-md5-6": {16, 6, newRC4MD5Stream}, + "chacha20": {32, 8, newChaCha20Stream}, + "chacha20-ietf": {32, 12, newChaCha20IETFStream}, + "salsa20": {32, 8, newSalsa20Stream}, +} + +func CheckCipherMethod(method string) error { + if method == "" { + method = "aes-256-cfb" + } + _, ok := cipherMethod[method] + if !ok { + return errors.New("Unsupported encryption method: " + method) + } + return nil +} + +type Cipher struct { + enc cipher.Stream + dec cipher.Stream + key []byte + info *cipherInfo + ota bool // one-time auth + iv []byte +} + +// NewCipher creates a cipher that can be used in Dial() etc. +// Use cipher.Copy() to create a new cipher with the same method and password +// to avoid the cost of repeated cipher initialization. +func NewCipher(method, password string) (c *Cipher, err error) { + if password == "" { + return nil, errEmptyPassword + } + var ota bool + if strings.HasSuffix(strings.ToLower(method), "-auth") { + method = method[:len(method)-5] // len("-auth") = 5 + ota = true + } else { + ota = false + } + mi, ok := cipherMethod[method] + if !ok { + return nil, errors.New("Unsupported encryption method: " + method) + } + + key := evpBytesToKey(password, mi.keyLen) + + c = &Cipher{key: key, info: mi} + + if err != nil { + return nil, err + } + c.ota = ota + return c, nil +} + +// Initializes the block cipher with CFB mode, returns IV. +func (c *Cipher) initEncrypt() (iv []byte, err error) { + if c.iv == nil { + iv = make([]byte, c.info.ivLen) + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + c.iv = iv + } else { + iv = c.iv + } + c.enc, err = c.info.newStream(c.key, iv, Encrypt) + return +} + +func (c *Cipher) initDecrypt(iv []byte) (err error) { + c.dec, err = c.info.newStream(c.key, iv, Decrypt) + return +} + +func (c *Cipher) encrypt(dst, src []byte) { + c.enc.XORKeyStream(dst, src) +} + +func (c *Cipher) decrypt(dst, src []byte) { + c.dec.XORKeyStream(dst, src) +} +func (c *Cipher) Encrypt(src []byte) (cipherData []byte) { + cipher := c.Copy() + iv, err := cipher.initEncrypt() + if err != nil { + return + } + packetLen := len(src) + len(iv) + cipherData = make([]byte, packetLen) + copy(cipherData, iv) + cipher.encrypt(cipherData[len(iv):], src) + return +} + +func (c *Cipher) Decrypt(src []byte) (data []byte) { + cipher := c.Copy() + if len(src) < c.info.ivLen { + return + } + iv := make([]byte, c.info.ivLen) + copy(iv, src[:c.info.ivLen]) + if err := cipher.initDecrypt(iv); err != nil { + return + } + data = make([]byte, len(src)-len(iv)) + cipher.decrypt(data[0:], src[c.info.ivLen:]) + return +} + +// Copy creates a new cipher at it's initial state. +func (c *Cipher) Copy() *Cipher { + // This optimization maybe not necessary. But without this function, we + // need to maintain a table cache for newTableCipher and use lock to + // protect concurrent access to that cache. + + // AES and DES ciphers does not return specific types, so it's difficult + // to create copy. But their initizliation time is less than 4000ns on my + // 2.26 GHz Intel Core 2 Duo processor. So no need to worry. + + // Currently, blow-fish and cast5 initialization cost is an order of + // maganitude slower than other ciphers. (I'm not sure whether this is + // because the current implementation is not highly optimized, or this is + // the nature of the algorithm.) + + nc := *c + nc.enc = nil + nc.dec = nil + nc.ota = c.ota + return &nc +} diff --git a/utils/ss/util.go b/utils/ss/util.go new file mode 100644 index 0000000..d4ae108 --- /dev/null +++ b/utils/ss/util.go @@ -0,0 +1,131 @@ +package ss + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "os" + "strconv" + + "bitbucket.org/snail/proxy/utils" +) + +const leakyBufSize = 4108 // data.len(2) + hmacsha1(10) + data(4096) +const maxNBuf = 2048 + +var leakyBuf = utils.NewLeakyBuf(maxNBuf, leakyBufSize) + +func IsFileExists(path string) (bool, error) { + stat, err := os.Stat(path) + if err == nil { + if stat.Mode()&os.ModeType == 0 { + return true, nil + } + return false, errors.New(path + " exists but is not regular file") + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func HmacSha1(key []byte, data []byte) []byte { + hmacSha1 := hmac.New(sha1.New, key) + hmacSha1.Write(data) + return hmacSha1.Sum(nil)[:10] +} + +func otaConnectAuth(iv, key, data []byte) []byte { + return append(data, HmacSha1(append(iv, key...), data)...) +} + +func otaReqChunkAuth(iv []byte, chunkId uint32, data []byte) []byte { + nb := make([]byte, 2) + binary.BigEndian.PutUint16(nb, uint16(len(data))) + chunkIdBytes := make([]byte, 4) + binary.BigEndian.PutUint32(chunkIdBytes, chunkId) + header := append(nb, HmacSha1(append(iv, chunkIdBytes...), data)...) + return append(header, data...) +} + +const ( + idType = 0 // address type index + idIP0 = 1 // ip addres start index + idDmLen = 1 // domain address length index + idDm0 = 2 // domain address start index + + typeIPv4 = 1 // type is ipv4 address + typeDm = 3 // type is domain address + typeIPv6 = 4 // type is ipv6 address + + lenIPv4 = net.IPv4len + 2 // ipv4 + 2port + lenIPv6 = net.IPv6len + 2 // ipv6 + 2port + lenDmBase = 2 // 1addrLen + 2port, plus addrLen + lenHmacSha1 = 10 +) + +func GetRequest(conn *Conn) (host string, err error) { + + // buf size should at least have the same size with the largest possible + // request size (when addrType is 3, domain name has at most 256 bytes) + // 1(addrType) + 1(lenByte) + 255(max length address) + 2(port) + 10(hmac-sha1) + buf := make([]byte, 269) + // read till we get possible domain length field + if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil { + return + } + + var reqStart, reqEnd int + addrType := buf[idType] + switch addrType & AddrMask { + case typeIPv4: + reqStart, reqEnd = idIP0, idIP0+lenIPv4 + case typeIPv6: + reqStart, reqEnd = idIP0, idIP0+lenIPv6 + case typeDm: + if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil { + return + } + reqStart, reqEnd = idDm0, idDm0+int(buf[idDmLen])+lenDmBase + default: + err = fmt.Errorf("addr type %d not supported", addrType&AddrMask) + return + } + + if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil { + return + } + + // Return string for typeIP is not most efficient, but browsers (Chrome, + // Safari, Firefox) all seems using typeDm exclusively. So this is not a + // big problem. + switch addrType & AddrMask { + case typeIPv4: + host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String() + case typeIPv6: + host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() + case typeDm: + host = string(buf[idDm0 : idDm0+int(buf[idDmLen])]) + } + // parse port + port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd]) + host = net.JoinHostPort(host, strconv.Itoa(int(port))) + + return +} + +type ClosedFlag struct { + flag bool +} + +func (flag *ClosedFlag) SetClosed() { + flag.flag = true +} + +func (flag *ClosedFlag) IsClosed() bool { + return flag.flag +}