merge enterprise

arraykeys@gmail.com
2018-09-04 16:00:08 +08:00
parent 89be79b6c6
commit a993b1bb9d
10 changed files with 2204 additions and 1 deletions

services/sps/ssudp.go Normal file

@@ -0,0 +1,161 @@
package sps
import (
"bytes"
"fmt"
"net"
"runtime/debug"
"time"
"bitbucket.org/snail/proxy/utils"
goaes "bitbucket.org/snail/proxy/utils/aes"
"bitbucket.org/snail/proxy/utils/socks"
)
func (s *SPS) RunSSUDP(addr string) (err error) {
a, _ := net.ResolveUDPAddr("udp", addr)
listener, err := net.ListenUDP("udp", a)
if err != nil {
s.log.Printf("ss udp bind error %s", err)
return
}
s.log.Printf("ss udp on %s", listener.LocalAddr())
s.udpRelatedPacketConns.Set(addr, listener)
go func() {
defer func() {
if e := recover(); e != nil {
s.log.Printf("udp local->out io copy crashed:\n%s\n%s", e, string(debug.Stack()))
}
}()
for {
buf := utils.LeakyBuffer.Get()
defer utils.LeakyBuffer.Put(buf)
n, srcAddr, err := listener.ReadFrom(buf)
if err != nil {
s.log.Printf("read from client error %s", err)
if utils.IsNetClosedErr(err) {
return
}
continue
}
var (
inconnRemoteAddr = srcAddr.String()
outUDPConn *net.UDPConn
outconn net.Conn
outconnLocalAddr string
destAddr *net.UDPAddr
clean = func(msg, err string) {
raddr := ""
if outUDPConn != nil {
raddr = outUDPConn.RemoteAddr().String()
outUDPConn.Close()
}
if msg != "" {
if raddr != "" {
s.log.Printf("%s , %s , %s -> %s", msg, err, inconnRemoteAddr, raddr)
} else {
s.log.Printf("%s , %s , from : %s", msg, err, inconnRemoteAddr)
}
}
s.userConns.Remove(inconnRemoteAddr)
if outconn != nil {
outconn.Close()
}
if outconnLocalAddr != "" {
s.userConns.Remove(outconnLocalAddr)
}
}
)
defer clean("", "")
raw := new(bytes.Buffer)
raw.Write([]byte{0x00, 0x00, 0x00})
raw.Write(s.localCipher.Decrypt(buf[:n]))
socksPacket := socks.NewPacketUDP()
err = socksPacket.Parse(raw.Bytes())
raw = nil
if err != nil {
s.log.Printf("udp parse error %s", err)
return
}
if v, ok := s.udpRelatedPacketConns.Get(inconnRemoteAddr); !ok {
//socks client
lbAddr := s.lb.Select(inconnRemoteAddr, *s.cfg.LoadBalanceOnlyHA)
outconn, err := s.GetParentConn(lbAddr)
if err != nil {
clean("connnect fail", fmt.Sprintf("%s", err))
return
}
client, err := s.HandshakeSocksParent(&outconn, "udp", socksPacket.Addr(), socks.Auth{}, true)
if err != nil {
clean("handshake fail", fmt.Sprintf("%s", err))
return
}
outconnLocalAddr = outconn.LocalAddr().String()
s.userConns.Set(outconnLocalAddr, &outconn)
go func() {
defer func() {
if e := recover(); e != nil {
s.log.Printf("udp related parent tcp conn read crashed:\n%s\n%s", e, string(debug.Stack()))
}
}()
buf := make([]byte, 1)
outconn.SetReadDeadline(time.Time{})
if _, err := outconn.Read(buf); err != nil {
clean("udp parent tcp conn disconnected", fmt.Sprintf("%s", err))
}
}()
destAddr, _ = net.ResolveUDPAddr("udp", client.UDPAddr)
localZeroAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
outUDPConn, err = net.DialUDP("udp", localZeroAddr, destAddr)
if err != nil {
s.log.Printf("create out udp conn fail , %s , from : %s", err, srcAddr)
return
}
s.udpRelatedPacketConns.Set(srcAddr.String(), outUDPConn)
utils.UDPCopy(listener, outUDPConn, srcAddr, time.Second*5, func(data []byte) []byte {
//forward to local
var v []byte
//convert parent data to raw
if len(s.udpParentKey) > 0 {
v, err = goaes.Decrypt(s.udpParentKey, data)
if err != nil {
s.log.Printf("udp outconn parse packet fail, %s", err.Error())
return []byte{}
}
} else {
v = data
}
return s.localCipher.Encrypt(v[3:])
}, func(err interface{}) {
s.udpRelatedPacketConns.Remove(srcAddr.String())
if err != nil {
s.log.Printf("udp out->local io copy crashed:\n%s\n%s", err, string(debug.Stack()))
}
})
} else {
outUDPConn = v.(*net.UDPConn)
}
//forward to parent
//the packet is still raw; convert it to the parent's format
var v []byte
if len(s.udpParentKey) > 0 {
v, _ = goaes.Encrypt(s.udpParentKey, socksPacket.Bytes())
} else {
v = socksPacket.Bytes()
}
_, err = outUDPConn.Write(v)
socksPacket = socks.PacketUDP{}
if err != nil {
if utils.IsNetClosedErr(err) {
return
}
s.log.Printf("send out udp data fail , %s , from : %s", err, srcAddr)
}
}
}()
return
}
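For reference: the handler above rebuilds a SOCKS5 UDP packet from each decrypted Shadowsocks datagram by prefixing the three RSV/FRAG bytes before parsing it. Below is a minimal standalone sketch of that framing step using only the standard library; the payload bytes are made-up example values, not data from the commit.

package main

import (
	"bytes"
	"fmt"
)

func main() {
	// A Shadowsocks UDP payload starts directly with the SOCKS5 address block:
	// ATYP(1) | ADDR | PORT(2) | DATA. Example: IPv4 1.2.3.4, port 53, data "hi".
	ssPayload := []byte{0x01, 1, 2, 3, 4, 0x00, 0x35, 'h', 'i'}

	// A SOCKS5 UDP request (RFC 1928) carries RSV(2) + FRAG(1) in front,
	// which is what RunSSUDP prepends before handing the bytes to the socks parser.
	raw := new(bytes.Buffer)
	raw.Write([]byte{0x00, 0x00, 0x00})
	raw.Write(ssPayload)

	fmt.Printf("socks5 udp packet: % x\n", raw.Bytes())
}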

utils/datasize/datasize.go Normal file

@@ -0,0 +1,235 @@
package datasize
import (
"errors"
"fmt"
"strconv"
"strings"
)
type ByteSize uint64
const (
B ByteSize = 1
KB = B << 10
MB = KB << 10
GB = MB << 10
TB = GB << 10
PB = TB << 10
EB = PB << 10
fnUnmarshalText string = "UnmarshalText"
maxUint64 uint64 = (1 << 64) - 1
cutoff uint64 = maxUint64 / 10
)
var ErrBits = errors.New("unit with capital unit prefix and lower case unit (b) - bits, not bytes ")
var defaultDatasize ByteSize
func Parse(s string) (bytes uint64, err error) {
err = defaultDatasize.UnmarshalText([]byte(s))
if err != nil {
return
}
bytes = defaultDatasize.Bytes()
return
}
func HumanSize(bytes uint64) (s string, err error) {
err = defaultDatasize.UnmarshalText([]byte(fmt.Sprintf("%d", bytes)))
if err != nil {
return
}
s = defaultDatasize.HumanReadable()
return
}
func (b ByteSize) Bytes() uint64 {
return uint64(b)
}
func (b ByteSize) KBytes() float64 {
v := b / KB
r := b % KB
return float64(v) + float64(r)/float64(KB)
}
func (b ByteSize) MBytes() float64 {
v := b / MB
r := b % MB
return float64(v) + float64(r)/float64(MB)
}
func (b ByteSize) GBytes() float64 {
v := b / GB
r := b % GB
return float64(v) + float64(r)/float64(GB)
}
func (b ByteSize) TBytes() float64 {
v := b / TB
r := b % TB
return float64(v) + float64(r)/float64(TB)
}
func (b ByteSize) PBytes() float64 {
v := b / PB
r := b % PB
return float64(v) + float64(r)/float64(PB)
}
func (b ByteSize) EBytes() float64 {
v := b / EB
r := b % EB
return float64(v) + float64(r)/float64(EB)
}
func (b ByteSize) String() string {
switch {
case b == 0:
return fmt.Sprint("0B")
case b%EB == 0:
return fmt.Sprintf("%dEB", b/EB)
case b%PB == 0:
return fmt.Sprintf("%dPB", b/PB)
case b%TB == 0:
return fmt.Sprintf("%dTB", b/TB)
case b%GB == 0:
return fmt.Sprintf("%dGB", b/GB)
case b%MB == 0:
return fmt.Sprintf("%dMB", b/MB)
case b%KB == 0:
return fmt.Sprintf("%dKB", b/KB)
default:
return fmt.Sprintf("%dB", b)
}
}
func (b ByteSize) HR() string {
return b.HumanReadable()
}
func (b ByteSize) HumanReadable() string {
switch {
case b > EB:
return fmt.Sprintf("%.1f EB", b.EBytes())
case b > PB:
return fmt.Sprintf("%.1f PB", b.PBytes())
case b > TB:
return fmt.Sprintf("%.1f TB", b.TBytes())
case b > GB:
return fmt.Sprintf("%.1f GB", b.GBytes())
case b > MB:
return fmt.Sprintf("%.1f MB", b.MBytes())
case b > KB:
return fmt.Sprintf("%.1f KB", b.KBytes())
default:
return fmt.Sprintf("%d B", b)
}
}
func (b ByteSize) MarshalText() ([]byte, error) {
return []byte(b.String()), nil
}
func (b *ByteSize) UnmarshalText(t []byte) error {
var val uint64
var unit string
// copy for error message
t0 := t
var c byte
var i int
ParseLoop:
for i < len(t) {
c = t[i]
switch {
case '0' <= c && c <= '9':
if val > cutoff {
goto Overflow
}
c = c - '0'
val *= 10
if val > val+uint64(c) {
// val+v overflows
goto Overflow
}
val += uint64(c)
i++
default:
if i == 0 {
goto SyntaxError
}
break ParseLoop
}
}
unit = strings.TrimSpace(string(t[i:]))
switch unit {
case "Kb", "Mb", "Gb", "Tb", "Pb", "Eb":
goto BitsError
}
unit = strings.ToLower(unit)
switch unit {
case "", "b", "byte":
// do nothing - already in bytes
case "k", "kb", "kilo", "kilobyte", "kilobytes":
if val > maxUint64/uint64(KB) {
goto Overflow
}
val *= uint64(KB)
case "m", "mb", "mega", "megabyte", "megabytes":
if val > maxUint64/uint64(MB) {
goto Overflow
}
val *= uint64(MB)
case "g", "gb", "giga", "gigabyte", "gigabytes":
if val > maxUint64/uint64(GB) {
goto Overflow
}
val *= uint64(GB)
case "t", "tb", "tera", "terabyte", "terabytes":
if val > maxUint64/uint64(TB) {
goto Overflow
}
val *= uint64(TB)
case "p", "pb", "peta", "petabyte", "petabytes":
if val > maxUint64/uint64(PB) {
goto Overflow
}
val *= uint64(PB)
case "E", "EB", "e", "eb", "eB":
if val > maxUint64/uint64(EB) {
goto Overflow
}
val *= uint64(EB)
default:
goto SyntaxError
}
*b = ByteSize(val)
return nil
Overflow:
*b = ByteSize(maxUint64)
return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrRange}
SyntaxError:
*b = 0
return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrSyntax}
BitsError:
*b = 0
return &strconv.NumError{fnUnmarshalText, string(t0), ErrBits}
}
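For reference, a minimal usage sketch of the helpers above; the sizes are arbitrary example values and the import path follows the repository layout in this commit.

package main

import (
	"fmt"

	"bitbucket.org/snail/proxy/utils/datasize"
)

func main() {
	// Parse turns a human-readable size string into a byte count.
	n, err := datasize.Parse("100KB")
	if err != nil {
		panic(err)
	}
	fmt.Println(n) // 102400

	// HumanSize goes the other way.
	s, err := datasize.HumanSize(3 * 1024 * 1024 * 1024)
	if err != nil {
		panic(err)
	}
	fmt.Println(s) // 3.0 GB
}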

utils/dnsx/resolver.go Normal file

@@ -0,0 +1,114 @@
package dnsx
import (
"fmt"
logger "log"
"net"
"strings"
"time"
"bitbucket.org/snail/proxy/utils/mapx"
dns "github.com/miekg/dns"
)
type DomainResolver struct {
ttl int
dnsAddrress string
data mapx.ConcurrentMap
log *logger.Logger
}
type DomainResolverItem struct {
ip string
domain string
expiredAt int64
}
func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver {
return DomainResolver{
ttl: ttl,
dnsAddrress: dnsAddrress,
data: mapx.NewConcurrentMap(),
log: log,
}
}
func (a *DomainResolver) DnsAddress() (address string) {
address = a.dnsAddrress
return
}
func (a *DomainResolver) MustResolve(address string) (ip string) {
ip, _ = a.Resolve(address)
return
}
func (a *DomainResolver) Resolve(address string) (ip string, err error) {
domain := address
port := ""
fromCache := "false"
defer func() {
if port != "" {
ip = net.JoinHostPort(ip, port)
}
a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache)
//a.PrintData()
}()
if strings.Contains(domain, ":") {
domain, port, err = net.SplitHostPort(domain)
if err != nil {
return
}
}
if net.ParseIP(domain) != nil {
ip = domain
fromCache = "ip ignore"
return
}
item, ok := a.data.Get(domain)
if ok {
//log.Println("find ", domain)
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
ip = (*item.(*DomainResolverItem)).ip
fromCache = "true"
//log.Println("from cache ", domain)
return
}
} else {
item = &DomainResolverItem{
domain: domain,
}
}
c := new(dns.Client)
c.DialTimeout = time.Millisecond * 5000
c.ReadTimeout = time.Millisecond * 5000
c.WriteTimeout = time.Millisecond * 5000
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
m.RecursionDesired = true
r, _, err := c.Exchange(m, a.dnsAddrress)
if r == nil {
return
}
if r.Rcode != dns.RcodeSuccess {
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
return
}
for _, answer := range r.Answer {
if answer.Header().Rrtype == dns.TypeA {
info := strings.Fields(answer.String())
if len(info) >= 5 {
ip = info[4]
_item := item.(*DomainResolverItem)
(*_item).expiredAt = time.Now().Unix() + int64(a.ttl)
(*_item).ip = ip
a.data.Set(domain, item)
return
}
}
}
return
}
func (a *DomainResolver) PrintData() {
for k, item := range a.data.Items() {
d := item.(*DomainResolverItem)
a.log.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt)
}
}
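A hedged usage sketch of the resolver above; the DNS server address, domain and TTL are placeholder values.

package main

import (
	"log"
	"os"

	"bitbucket.org/snail/proxy/utils/dnsx"
)

func main() {
	logger := log.New(os.Stdout, "", log.LstdFlags)
	// Cache answers for 300 seconds.
	r := dnsx.NewDomainResolver("8.8.8.8:53", 300, logger)

	// Resolve keeps the port and only replaces the host part,
	// serving repeated lookups from the cache until the TTL expires.
	ip, err := r.Resolve("example.com:443")
	if err != nil {
		log.Fatal(err)
	}
	logger.Println(ip)
}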


@@ -0,0 +1,178 @@
package iolimiter
import (
"context"
"io"
"net"
"time"
"golang.org/x/time/rate"
)
const burstLimit = 1000 * 1000 * 1000
type Reader struct {
r io.Reader
limiter *rate.Limiter
ctx context.Context
}
type Writer struct {
w io.Writer
limiter *rate.Limiter
ctx context.Context
}
type conn struct {
net.Conn
r io.Reader
w io.Writer
readLimiter *rate.Limiter
writeLimiter *rate.Limiter
ctx context.Context
}
// NewtConn sets a rate limit (bytes/sec) on both reads and writes of the Conn.
func NewtConn(c net.Conn, bytesPerSec float64) net.Conn {
s := &conn{
Conn: c,
r: c,
w: c,
ctx: context.Background(),
}
s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
return s
}
// NewReaderConn sets a rate limit (bytes/sec) on reads of the Conn.
func NewReaderConn(c net.Conn, bytesPerSec float64) net.Conn {
s := &conn{
Conn: c,
r: c,
w: c,
ctx: context.Background(),
}
s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
return s
}
// NewWriterConn sets a rate limit (bytes/sec) on writes of the Conn.
func NewWriterConn(c net.Conn, bytesPerSec float64) net.Conn {
s := &conn{
Conn: c,
r: c,
w: c,
ctx: context.Background(),
}
s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
return s
}
// Read reads bytes into p.
func (s *conn) Read(p []byte) (int, error) {
if s.readLimiter == nil {
return s.r.Read(p)
}
n, err := s.r.Read(p)
if err != nil {
return n, err
}
if err := s.readLimiter.WaitN(s.ctx, n); err != nil {
return n, err
}
return n, nil
}
// Write writes bytes from p.
func (s *conn) Write(p []byte) (int, error) {
if s.writeLimiter == nil {
return s.w.Write(p)
}
n, err := s.w.Write(p)
if err != nil {
return n, err
}
if err := s.writeLimiter.WaitN(s.ctx, n); err != nil {
return n, err
}
return n, err
}
// NewReader returns a reader that implements io.Reader with rate limiting.
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
ctx: context.Background(),
}
}
// NewReaderWithContext returns a reader that implements io.Reader with rate limiting.
func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader {
return &Reader{
r: r,
ctx: ctx,
}
}
// NewWriter returns a writer that implements io.Writer with rate limiting.
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
ctx: context.Background(),
}
}
// NewWriterWithContext returns a writer that implements io.Writer with rate limiting.
func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer {
return &Writer{
w: w,
ctx: ctx,
}
}
// SetRateLimit sets rate limit (bytes/sec) to the reader.
func (s *Reader) SetRateLimit(bytesPerSec float64) {
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
}
// Read reads bytes into p.
func (s *Reader) Read(p []byte) (int, error) {
if s.limiter == nil {
return s.r.Read(p)
}
n, err := s.r.Read(p)
if err != nil {
return n, err
}
if err := s.limiter.WaitN(s.ctx, n); err != nil {
return n, err
}
return n, nil
}
// SetRateLimit sets rate limit (bytes/sec) to the writer.
func (s *Writer) SetRateLimit(bytesPerSec float64) {
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
}
// Write writes bytes from p.
func (s *Writer) Write(p []byte) (int, error) {
if s.limiter == nil {
return s.w.Write(p)
}
n, err := s.w.Write(p)
if err != nil {
return n, err
}
if err := s.limiter.WaitN(s.ctx, n); err != nil {
return n, err
}
return n, err
}
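A minimal usage sketch of the limiter above; the import path is an assumption that mirrors the other utils packages in this commit, and the target address and rate are arbitrary examples.

package main

import (
	"io"
	"net"
	"os"

	"bitbucket.org/snail/proxy/utils/iolimiter"
)

func main() {
	c, err := net.Dial("tcp", "example.com:80")
	if err != nil {
		panic(err)
	}
	defer c.Close()

	// Limit reads on this connection to roughly 32 KB/s; writes pass through.
	limited := iolimiter.NewReaderConn(c, 32*1024)

	limited.Write([]byte("HEAD / HTTP/1.0\r\nHost: example.com\r\n\r\n"))
	io.Copy(os.Stdout, limited)
}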

utils/lb/backend.go Normal file

@@ -0,0 +1,203 @@
package lb
import (
"errors"
"log"
"net"
"sync"
"time"
"bitbucket.org/snail/proxy/utils/dnsx"
)
// BackendConfig is the configuration for a single backend
type BackendConfig struct {
Address string
ActiveAfter int
InactiveAfter int
Weight int
Timeout time.Duration
RetryTime time.Duration
IsMuxCheck bool
ConnFactory func(address string, timeout time.Duration) (net.Conn, error)
}
type BackendsConfig []*BackendConfig
// BackendControl keeps the runtime control data
type BackendControl struct {
Failed bool // The last request failed
Active bool
InactiveTries int
ActiveTries int
Connections int
ConnectUsedMillisecond int
isStop bool
}
// Backend structure
type Backend struct {
BackendConfig
BackendControl
sync.RWMutex
log *log.Logger
dr *dnsx.DomainResolver
}
type Backends []*Backend
func NewBackend(backendConfig BackendConfig, dr *dnsx.DomainResolver, log *log.Logger) (*Backend, error) {
if backendConfig.Address == "" {
return nil, errors.New("Address rquired")
}
if backendConfig.ActiveAfter == 0 {
backendConfig.ActiveAfter = 2
}
if backendConfig.InactiveAfter == 0 {
backendConfig.InactiveAfter = 3
}
if backendConfig.Weight == 0 {
backendConfig.Weight = 1
}
if backendConfig.Timeout == 0 {
backendConfig.Timeout = time.Millisecond * 1500
}
if backendConfig.RetryTime == 0 {
backendConfig.RetryTime = time.Millisecond * 2000
}
return &Backend{
dr: dr,
log: log,
BackendConfig: backendConfig,
BackendControl: BackendControl{
Failed: true,
Active: false,
InactiveTries: 0,
ActiveTries: 0,
Connections: 0,
ConnectUsedMillisecond: 0,
isStop: false,
},
}, nil
}
func (b *Backend) StopHeartCheck() {
b.isStop = true
}
func (b *Backend) IncreasConns() {
b.RWMutex.Lock()
b.Connections++
b.RWMutex.Unlock()
}
func (b *Backend) DecreaseConns() {
b.RWMutex.Lock()
b.Connections--
b.RWMutex.Unlock()
}
func (b *Backend) StartHeartCheck() {
if b.IsMuxCheck {
b.startMuxHeartCheck()
} else {
b.startTCPHeartCheck()
}
}
func (b *Backend) startMuxHeartCheck() {
go func() {
for {
if b.isStop {
return
}
var c net.Conn
var err error
start := time.Now().UnixNano() / int64(time.Millisecond)
c, err = b.getConn()
b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Millisecond) - start)
if err != nil {
b.Active = false
time.Sleep(time.Second * 2)
continue
} else {
b.Active = true
}
// block until the peer closes the mux connection
buf := make([]byte, 1)
c.Read(buf)
b.Active = false
}
}()
}
// Monitoring the backend
func (b *Backend) startTCPHeartCheck() {
go func() {
for {
if b.isStop {
return
}
var c net.Conn
var err error
start := time.Now().UnixNano() / int64(time.Millisecond)
c, err = b.getConn()
b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Millisecond) - start)
if err == nil {
c.Close()
}
if err != nil {
b.RWMutex.Lock()
// Max tries before considering the backend inactive
if b.InactiveTries >= b.InactiveAfter {
//b.log.Printf("Backend inactive [%s]", b.Address)
b.Active = false
b.ActiveTries = 0
} else {
// not yet past the threshold; mark this check as failed and count it
b.Failed = true
b.InactiveTries++
//b.log.Printf("Error to check address [%s] tries [%d]", b.Address, b.InactiveTries)
}
b.RWMutex.Unlock()
} else {
// the check succeeded; count it toward becoming active
b.RWMutex.Lock()
if b.ActiveTries >= b.ActiveAfter {
if b.Failed {
//log.Printf("Backend active [%s]", b.Address)
}
b.Failed = false
b.Active = true
b.InactiveTries = 0
} else {
b.ActiveTries++
}
b.RWMutex.Unlock()
}
time.Sleep(b.RetryTime)
}
}()
}
func (b *Backend) getConn() (conn net.Conn, err error) {
address := b.Address
if b.dr != nil && b.dr.DnsAddress() != "" {
address, err = b.dr.Resolve(b.Address)
if err != nil {
b.log.Printf("dns error %s , ERR:%s", b.Address, err)
}
}
if b.ConnFactory != nil {
return b.ConnFactory(address, b.Timeout)
}
return net.DialTimeout("tcp", address, b.Timeout)
}
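A minimal sketch of creating one Backend directly; the address is a placeholder, no resolver or custom ConnFactory is supplied, and the zero fields rely on the defaults NewBackend fills in (ActiveAfter 2, InactiveAfter 3, Weight 1, Timeout 1500ms, RetryTime 2000ms).

package main

import (
	"log"
	"os"
	"time"

	"bitbucket.org/snail/proxy/utils/lb"
)

func main() {
	logger := log.New(os.Stdout, "", log.LstdFlags)
	b, err := lb.NewBackend(lb.BackendConfig{Address: "10.0.0.1:8080"}, nil, logger)
	if err != nil {
		log.Fatal(err)
	}
	b.StartHeartCheck()
	defer b.StopHeartCheck()

	time.Sleep(3 * time.Second) // give the checker time to probe once
	logger.Println("active:", b.Active)
}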

utils/lb/lb.go Normal file

@@ -0,0 +1,687 @@
package lb
import (
"crypto/md5"
"log"
"net"
"sync"
"bitbucket.org/snail/proxy/utils/dnsx"
)
const (
SELECT_ROUNDROBIN = iota
SELECT_LEASTCONN
SELECT_HASH
SELECT_WEITHT
SELECT_LEASTTIME
)
type Selector interface {
Select(srcAddr string) (addr string)
SelectBackend(srcAddr string) (b *Backend)
IncreasConns(addr string)
DecreaseConns(addr string)
Stop()
Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger)
IsActive() bool
ActiveCount() (count int)
Backends() (bs []*Backend)
}
type Group struct {
selector *Selector
log *log.Logger
dr *dnsx.DomainResolver
lock *sync.Mutex
last *Backend
debug bool
}
func NewGroup(selectType int, configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger, debug bool) Group {
bks := []*Backend{}
for _, c := range configs {
b, _ := NewBackend(*c, dr, log)
bks = append(bks, b)
}
if len(bks) > 1 {
for _, b := range bks {
b.StartHeartCheck()
}
}
var s Selector
switch selectType {
case SELECT_ROUNDROBIN:
s = NewRoundRobin(bks, log, debug)
case SELECT_LEASTCONN:
s = NewLeastConn(bks, log, debug)
case SELECT_HASH:
s = NewHash(bks, log, debug)
case SELECT_WEITHT:
s = NewWeight(bks, log, debug)
case SELECT_LEASTTIME:
s = NewLeastTime(bks, log, debug)
}
return Group{
selector: &s,
log: log,
dr: dr,
lock: &sync.Mutex{},
debug: debug,
}
}
func (g *Group) Select(srcAddr string, onlyHa bool) (addr string) {
if onlyHa {
g.lock.Lock()
defer g.lock.Unlock()
if g.last != nil && (g.last.Active || g.last.ConnectUsedMillisecond == 0) {
if g.debug {
g.log.Printf("############ choosed %s from lastest ############", g.last.Address)
printDebug(true, g.log, nil, srcAddr, (*g.selector).Backends())
}
return g.last.Address
}
g.last = (*g.selector).SelectBackend(srcAddr)
if !g.last.Active && g.last.ConnectUsedMillisecond > 0 {
g.log.Printf("###warn### lb selected empty , return default , for : %s", srcAddr)
}
return g.last.Address
}
b := (*g.selector).SelectBackend(srcAddr)
return b.Address
}
func (g *Group) IncreasConns(addr string) {
(*g.selector).IncreasConns(addr)
}
func (g *Group) DecreaseConns(addr string) {
(*g.selector).DecreaseConns(addr)
}
func (g *Group) Stop() {
if g.selector != nil {
(*g.selector).Stop()
}
}
func (g *Group) IsActive() bool {
return (*g.selector).IsActive()
}
func (g *Group) ActiveCount() (count int) {
return (*g.selector).ActiveCount()
}
func (g *Group) Reset(addrs []string) {
bks := (*g.selector).Backends()
if len(bks) == 0 {
return
}
cfg := bks[0].BackendConfig
configs := BackendsConfig{}
for _, addr := range addrs {
c := cfg
c.Address = addr
configs = append(configs, &c)
}
(*g.selector).Reset(configs, g.dr, g.log)
}
func (g *Group) Backends() []*Backend {
return (*g.selector).Backends()
}
//########################RoundRobin##########################
type RoundRobin struct {
sync.Mutex
backendIndex int
backends Backends
log *log.Logger
debug bool
}
func NewRoundRobin(backends Backends, log *log.Logger, debug bool) Selector {
return &RoundRobin{
backends: backends,
log: log,
debug: debug,
}
}
func (r *RoundRobin) Select(srcAddr string) (addr string) {
return r.SelectBackend(srcAddr).Address
}
func (r *RoundRobin) SelectBackend(srcAddr string) (b *Backend) {
r.Lock()
defer r.Unlock()
defer func() {
printDebug(r.debug, r.log, b, srcAddr, r.backends)
}()
if len(r.backends) == 0 {
return
}
if len(r.backends) == 1 {
return r.backends[0]
}
RETRY:
found := false
for _, b := range r.backends {
if b.Active {
found = true
break
}
}
if !found {
return r.backends[0]
}
r.backendIndex++
if r.backendIndex > len(r.backends)-1 {
r.backendIndex = 0
}
if !r.backends[r.backendIndex].Active {
goto RETRY
}
return r.backends[r.backendIndex]
}
func (r *RoundRobin) IncreasConns(addr string) {
}
func (r *RoundRobin) DecreaseConns(addr string) {
}
func (r *RoundRobin) Stop() {
for _, b := range r.backends {
b.StopHeartCheck()
}
}
func (r *RoundRobin) Backends() []*Backend {
return r.backends
}
func (r *RoundRobin) IsActive() bool {
for _, b := range r.backends {
if b.Active {
return true
}
}
return false
}
func (r *RoundRobin) ActiveCount() (count int) {
for _, b := range r.backends {
if b.Active {
count++
}
}
return
}
func (r *RoundRobin) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
r.Lock()
defer r.Unlock()
r.Stop()
bks := []*Backend{}
for _, c := range configs {
b, _ := NewBackend(*c, dr, log)
bks = append(bks, b)
}
if len(bks) > 1 {
for _, b := range bks {
b.StartHeartCheck()
}
}
r.backends = bks
}
//########################LeastConn##########################
type LeastConn struct {
sync.Mutex
backends Backends
log *log.Logger
debug bool
}
func NewLeastConn(backends []*Backend, log *log.Logger, debug bool) Selector {
lc := LeastConn{
backends: backends,
log: log,
debug: debug,
}
return &lc
}
func (lc *LeastConn) Select(srcAddr string) (addr string) {
return lc.SelectBackend(srcAddr).Address
}
func (lc *LeastConn) SelectBackend(srcAddr string) (b *Backend) {
lc.Lock()
defer lc.Unlock()
defer func() {
printDebug(lc.debug, lc.log, b, srcAddr, lc.backends)
}()
if len(lc.backends) == 0 {
return
}
if len(lc.backends) == 1 {
return lc.backends[0]
}
found := false
for _, b := range lc.backends {
if b.Active {
found = true
break
}
}
if !found {
return lc.backends[0]
}
min := lc.backends[0].Connections
index := 0
for i, b := range lc.backends {
if b.Active {
min = b.Connections
index = i
break
}
}
for i, b := range lc.backends {
if b.Active && b.Connections <= min {
min = b.Connections
index = i
}
}
return lc.backends[index]
}
func (lc *LeastConn) IncreasConns(addr string) {
for _, a := range lc.backends {
if a.Address == addr {
a.IncreasConns()
return
}
}
}
func (lc *LeastConn) DecreaseConns(addr string) {
for _, a := range lc.backends {
if a.Address == addr {
a.DecreaseConns()
return
}
}
}
func (lc *LeastConn) Stop() {
for _, b := range lc.backends {
b.StopHeartCheck()
}
}
func (lc *LeastConn) IsActive() bool {
for _, b := range lc.backends {
if b.Active {
return true
}
}
return false
}
func (lc *LeastConn) ActiveCount() (count int) {
for _, b := range lc.backends {
if b.Active {
count++
}
}
return
}
func (lc *LeastConn) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
lc.Lock()
defer lc.Unlock()
lc.Stop()
bks := []*Backend{}
for _, c := range configs {
b, _ := NewBackend(*c, dr, log)
bks = append(bks, b)
}
if len(bks) > 1 {
for _, b := range bks {
b.StartHeartCheck()
}
}
lc.backends = bks
}
func (lc *LeastConn) Backends() []*Backend {
return lc.backends
}
//########################Hash##########################
type Hash struct {
sync.Mutex
backends Backends
log *log.Logger
debug bool
}
func NewHash(backends Backends, log *log.Logger, debug bool) Selector {
return &Hash{
backends: backends,
log: log,
debug: debug,
}
}
func (h *Hash) Select(srcAddr string) (addr string) {
return h.SelectBackend(srcAddr).Address
}
func (h *Hash) SelectBackend(srcAddr string) (b *Backend) {
h.Lock()
defer h.Unlock()
defer func() {
printDebug(h.debug, h.log, b, srcAddr, h.backends)
}()
if len(h.backends) == 0 {
return
}
if len(h.backends) == 1 {
return h.backends[0]
}
i := 0
host, _, err := net.SplitHostPort(srcAddr)
if err != nil {
return
}
//porti, _ := strconv.Atoi(port)
//i += porti
for _, b := range md5.Sum([]byte(host)) {
i += int(b)
}
RETRY:
found := false
for _, b := range h.backends {
if b.Active {
found = true
break
}
}
if !found {
return h.backends[0]
}
k := i % len(h.backends)
if !h.backends[k].Active {
i++
goto RETRY
}
return h.backends[k]
}
func (h *Hash) IncreasConns(addr string) {
}
func (h *Hash) DecreaseConns(addr string) {
}
func (h *Hash) Stop() {
for _, b := range h.backends {
b.StopHeartCheck()
}
}
func (h *Hash) IsActive() bool {
for _, b := range h.backends {
if b.Active {
return true
}
}
return false
}
func (h *Hash) ActiveCount() (count int) {
for _, b := range h.backends {
if b.Active {
count++
}
}
return
}
func (h *Hash) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
h.Lock()
defer h.Unlock()
h.Stop()
bks := []*Backend{}
for _, c := range configs {
b, _ := NewBackend(*c, dr, log)
bks = append(bks, b)
}
if len(bks) > 1 {
for _, b := range bks {
b.StartHeartCheck()
}
}
h.backends = bks
}
func (h *Hash) Backends() []*Backend {
return h.backends
}
//########################Weight##########################
type Weight struct {
sync.Mutex
backends Backends
log *log.Logger
debug bool
}
func NewWeight(backends Backends, log *log.Logger, debug bool) Selector {
return &Weight{
backends: backends,
log: log,
debug: debug,
}
}
func (w *Weight) Select(srcAddr string) (addr string) {
return w.SelectBackend(srcAddr).Address
}
func (w *Weight) SelectBackend(srcAddr string) (b *Backend) {
w.Lock()
defer w.Unlock()
defer func() {
printDebug(w.debug, w.log, b, srcAddr, w.backends)
}()
if len(w.backends) == 0 {
return
}
if len(w.backends) == 1 {
return w.backends[0]
}
found := false
for _, b := range w.backends {
if b.Active {
found = true
break
}
}
if !found {
return w.backends[0]
}
min := w.backends[0].Connections / w.backends[0].Weight
index := 0
for i, b := range w.backends {
if b.Active {
min = b.Connections / b.Weight
index = i
break
}
}
for i, b := range w.backends {
if b.Active && b.Connections/b.Weight <= min {
min = b.Connections / b.Weight
index = i
}
}
return w.backends[index]
}
func (w *Weight) IncreasConns(addr string) {
w.Lock()
defer w.Unlock()
for _, a := range w.backends {
if a.Address == addr {
a.IncreasConns()
return
}
}
}
func (w *Weight) DecreaseConns(addr string) {
w.Lock()
defer w.Unlock()
for _, a := range w.backends {
if a.Address == addr {
a.DecreaseConns()
return
}
}
}
func (w *Weight) Stop() {
for _, b := range w.backends {
b.StopHeartCheck()
}
}
func (w *Weight) IsActive() bool {
for _, b := range w.backends {
if b.Active {
return true
}
}
return false
}
func (w *Weight) ActiveCount() (count int) {
for _, b := range w.backends {
if b.Active {
count++
}
}
return
}
func (w *Weight) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
w.Lock()
defer w.Unlock()
w.Stop()
bks := []*Backend{}
for _, c := range configs {
b, _ := NewBackend(*c, dr, log)
bks = append(bks, b)
}
if len(bks) > 1 {
for _, b := range bks {
b.StartHeartCheck()
}
}
w.backends = bks
}
func (w *Weight) Backends() []*Backend {
return w.backends
}
//########################LeastTime##########################
type LeastTime struct {
sync.Mutex
backends Backends
log *log.Logger
debug bool
}
func NewLeastTime(backends []*Backend, log *log.Logger, debug bool) Selector {
lt := LeastTime{
backends: backends,
log: log,
debug: debug,
}
return &lt
}
func (lt *LeastTime) Select(srcAddr string) (addr string) {
return lt.SelectBackend(srcAddr).Address
}
func (lt *LeastTime) SelectBackend(srcAddr string) (b *Backend) {
lt.Lock()
defer lt.Unlock()
defer func() {
printDebug(lt.debug, lt.log, b, srcAddr, lt.backends)
}()
if len(lt.backends) == 0 {
return
}
if len(lt.backends) == 1 {
return lt.backends[0]
}
found := false
for _, b := range lt.backends {
if b.Active {
found = true
break
}
}
if !found {
return lt.backends[0]
}
min := lt.backends[0].ConnectUsedMillisecond
index := 0
for i, b := range lt.backends {
if b.Active {
min = b.ConnectUsedMillisecond
index = i
break
}
}
for i, b := range lt.backends {
if b.Active && b.ConnectUsedMillisecond > 0 && b.ConnectUsedMillisecond <= min {
min = b.ConnectUsedMillisecond
index = i
}
}
return lt.backends[index]
}
func (lt *LeastTime) IncreasConns(addr string) {
}
func (lt *LeastTime) DecreaseConns(addr string) {
}
func (lt *LeastTime) Stop() {
for _, b := range lt.backends {
b.StopHeartCheck()
}
}
func (lt *LeastTime) IsActive() bool {
for _, b := range lt.backends {
if b.Active {
return true
}
}
return false
}
func (lt *LeastTime) ActiveCount() (count int) {
for _, b := range lt.backends {
if b.Active {
count++
}
}
return
}
func (lt *LeastTime) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
lt.Lock()
defer lt.Unlock()
lt.Stop()
bks := []*Backend{}
for _, c := range configs {
b, _ := NewBackend(*c, dr, log)
bks = append(bks, b)
}
if len(bks) > 1 {
for _, b := range bks {
b.StartHeartCheck()
}
}
lt.backends = bks
}
func (lt *LeastTime) Backends() []*Backend {
return lt.backends
}
func printDebug(isDebug bool, log *log.Logger, selected *Backend, srcAddr string, backends []*Backend) {
if isDebug {
log.Printf("############ LB start ############\n")
if selected != nil {
log.Printf("choosed %s for %s\n", selected.Address, srcAddr)
}
for _, v := range backends {
log.Printf("addr:%s,conns:%d,time:%d,weight:%d,active:%v\n", v.Address, v.Connections, v.ConnectUsedMillisecond, v.Weight, v.Active)
}
log.Printf("############ LB end ############\n")
}
}
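A minimal sketch of wiring the selectors together through Group; the backend addresses and client address are placeholders, no DNS resolver is used, and debug logging is off.

package main

import (
	"log"
	"os"
	"time"

	"bitbucket.org/snail/proxy/utils/lb"
)

func main() {
	logger := log.New(os.Stdout, "", log.LstdFlags)
	configs := lb.BackendsConfig{
		{Address: "10.0.0.1:8080"},
		{Address: "10.0.0.2:8080"},
	}
	g := lb.NewGroup(lb.SELECT_ROUNDROBIN, configs, nil, logger, false)
	defer g.Stop()

	time.Sleep(3 * time.Second) // let the heart checks mark backends active

	// With onlyHa=false the configured selector picks the backend.
	addr := g.Select("192.168.1.100:51234", false)
	logger.Println("selected:", addr)
}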


@@ -1,4 +1,4 @@
-package utils
+package mapx
 import (
 	"encoding/json"

utils/ss/conn.go Normal file

@@ -0,0 +1,193 @@
package ss
import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
)
const (
OneTimeAuthMask byte = 0x10
AddrMask byte = 0xf
)
type Conn struct {
net.Conn
*Cipher
readBuf []byte
writeBuf []byte
chunkId uint32
}
func NewConn(c net.Conn, cipher *Cipher) *Conn {
return &Conn{
Conn: c,
Cipher: cipher,
readBuf: leakyBuf.Get(),
writeBuf: leakyBuf.Get()}
}
func (c *Conn) Close() error {
leakyBuf.Put(c.readBuf)
leakyBuf.Put(c.writeBuf)
return c.Conn.Close()
}
func RawAddr(addr string) (buf []byte, err error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("ss: address error %s %v", addr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("ss: invalid port %s", addr)
}
hostLen := len(host)
l := 1 + 1 + hostLen + 2 // addrType + lenByte + address + port
buf = make([]byte, l)
buf[0] = 3 // 3 means the address is domain name
buf[1] = byte(hostLen) // host address length followed by host address
copy(buf[2:], host)
binary.BigEndian.PutUint16(buf[2+hostLen:2+hostLen+2], uint16(port))
return
}
// This is intended for use by users implementing a local socks proxy.
// rawaddr should contain part of the data in the socks request, starting from the
// ATYP field. (Refer to rfc1928 for more information.)
func DialWithRawAddr(rawConn *net.Conn, rawaddr []byte, server string, cipher *Cipher) (c *Conn, err error) {
var conn net.Conn
if rawConn == nil {
conn, err = net.Dial("tcp", server)
}
if err != nil {
return
}
if rawConn != nil {
c = NewConn(*rawConn, cipher)
} else {
c = NewConn(conn, cipher)
}
if cipher.ota {
if c.enc == nil {
if _, err = c.initEncrypt(); err != nil {
return
}
}
// initEncrypt has already been called, so c.write will not prepend the iv; send it manually
c.Conn.Write(cipher.iv)
rawaddr[0] |= OneTimeAuthMask
rawaddr = otaConnectAuth(cipher.iv, cipher.key, rawaddr)
}
if _, err = c.write(rawaddr); err != nil {
c.Close()
return nil, err
}
return
}
// addr should be in the form of host:port
func Dial(addr, server string, cipher *Cipher) (c *Conn, err error) {
ra, err := RawAddr(addr)
if err != nil {
return
}
return DialWithRawAddr(nil, ra, server, cipher)
}
func (c *Conn) GetIv() (iv []byte) {
iv = make([]byte, len(c.iv))
copy(iv, c.iv)
return
}
func (c *Conn) GetKey() (key []byte) {
key = make([]byte, len(c.key))
copy(key, c.key)
return
}
func (c *Conn) IsOta() bool {
return c.ota
}
func (c *Conn) GetAndIncrChunkId() (chunkId uint32) {
chunkId = c.chunkId
c.chunkId += 1
return
}
func (c *Conn) Read(b []byte) (n int, err error) {
if c.dec == nil {
iv := make([]byte, c.info.ivLen)
if _, err = io.ReadFull(c.Conn, iv); err != nil {
return
}
if err = c.initDecrypt(iv); err != nil {
return
}
if len(c.iv) == 0 {
c.iv = iv
}
}
cipherData := c.readBuf
if len(b) > len(cipherData) {
cipherData = make([]byte, len(b))
} else {
cipherData = cipherData[:len(b)]
}
n, err = c.Conn.Read(cipherData)
if n > 0 {
c.decrypt(b[0:n], cipherData[0:n])
}
return
}
func (c *Conn) Write(b []byte) (n int, err error) {
nn := len(b)
if c.ota {
chunkId := c.GetAndIncrChunkId()
b = otaReqChunkAuth(c.iv, chunkId, b)
}
headerLen := len(b) - nn
n, err = c.write(b)
// Make sure 0 <= n <= len(b), where b is the slice passed in.
if n >= headerLen {
n -= headerLen
}
return
}
func (c *Conn) write(b []byte) (n int, err error) {
var iv []byte
if c.enc == nil {
iv, err = c.initEncrypt()
if err != nil {
return
}
}
cipherData := c.writeBuf
dataSize := len(b) + len(iv)
if dataSize > len(cipherData) {
cipherData = make([]byte, dataSize)
} else {
cipherData = cipherData[:dataSize]
}
if iv != nil {
// Put initialization vector in buffer, do a single write to send both
// iv and data.
copy(cipherData, iv)
}
c.encrypt(cipherData[len(iv):], b)
n, err = c.Conn.Write(cipherData)
return
}

utils/ss/encrypt.go Normal file

@@ -0,0 +1,301 @@
package ss
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/md5"
"crypto/rand"
"crypto/rc4"
"encoding/binary"
"errors"
"io"
"strings"
"github.com/Yawning/chacha20"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/salsa20/salsa"
)
var errEmptyPassword = errors.New("empty key")
func md5sum(d []byte) []byte {
h := md5.New()
h.Write(d)
return h.Sum(nil)
}
func evpBytesToKey(password string, keyLen int) (key []byte) {
const md5Len = 16
cnt := (keyLen-1)/md5Len + 1
m := make([]byte, cnt*md5Len)
copy(m, md5sum([]byte(password)))
// Repeatedly call md5 until bytes generated is enough.
// Each call to md5 uses data: prev md5 sum + password.
d := make([]byte, md5Len+len(password))
start := 0
for i := 1; i < cnt; i++ {
start += md5Len
copy(d, m[start-md5Len:start])
copy(d[md5Len:], password)
copy(m[start:], md5sum(d))
}
return m[:keyLen]
}
type DecOrEnc int
const (
Decrypt DecOrEnc = iota
Encrypt
)
func newStream(block cipher.Block, err error, key, iv []byte,
doe DecOrEnc) (cipher.Stream, error) {
if err != nil {
return nil, err
}
if doe == Encrypt {
return cipher.NewCFBEncrypter(block, iv), nil
} else {
return cipher.NewCFBDecrypter(block, iv), nil
}
}
func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
block, err := aes.NewCipher(key)
return newStream(block, err, key, iv, doe)
}
func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewCTR(block, iv), nil
}
func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
block, err := des.NewCipher(key)
return newStream(block, err, key, iv, doe)
}
func newBlowFishStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
block, err := blowfish.NewCipher(key)
return newStream(block, err, key, iv, doe)
}
func newCast5Stream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
block, err := cast5.NewCipher(key)
return newStream(block, err, key, iv, doe)
}
func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
h := md5.New()
h.Write(key)
h.Write(iv)
rc4key := h.Sum(nil)
return rc4.NewCipher(rc4key)
}
func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
return chacha20.NewCipher(key, iv)
}
func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
return chacha20.NewCipher(key, iv)
}
type salsaStreamCipher struct {
nonce [8]byte
key [32]byte
counter int
}
func (c *salsaStreamCipher) XORKeyStream(dst, src []byte) {
var buf []byte
padLen := c.counter % 64
dataSize := len(src) + padLen
if cap(dst) >= dataSize {
buf = dst[:dataSize]
} else if leakyBufSize >= dataSize {
buf = leakyBuf.Get()
defer leakyBuf.Put(buf)
buf = buf[:dataSize]
} else {
buf = make([]byte, dataSize)
}
var subNonce [16]byte
copy(subNonce[:], c.nonce[:])
binary.LittleEndian.PutUint64(subNonce[len(c.nonce):], uint64(c.counter/64))
// It's difficult to avoid data copy here. src or dst maybe slice from
// Conn.Read/Write, which can't have padding.
copy(buf[padLen:], src[:])
salsa.XORKeyStream(buf, buf, &subNonce, &c.key)
copy(dst, buf[padLen:])
c.counter += len(src)
}
func newSalsa20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
var c salsaStreamCipher
copy(c.nonce[:], iv[:8])
copy(c.key[:], key[:32])
return &c, nil
}
type cipherInfo struct {
keyLen int
ivLen int
newStream func(key, iv []byte, doe DecOrEnc) (cipher.Stream, error)
}
var cipherMethod = map[string]*cipherInfo{
"aes-128-cfb": {16, 16, newAESCFBStream},
"aes-192-cfb": {24, 16, newAESCFBStream},
"aes-256-cfb": {32, 16, newAESCFBStream},
"aes-128-ctr": {16, 16, newAESCTRStream},
"aes-192-ctr": {24, 16, newAESCTRStream},
"aes-256-ctr": {32, 16, newAESCTRStream},
"des-cfb": {8, 8, newDESStream},
"bf-cfb": {16, 8, newBlowFishStream},
"cast5-cfb": {16, 8, newCast5Stream},
"rc4-md5": {16, 16, newRC4MD5Stream},
"rc4-md5-6": {16, 6, newRC4MD5Stream},
"chacha20": {32, 8, newChaCha20Stream},
"chacha20-ietf": {32, 12, newChaCha20IETFStream},
"salsa20": {32, 8, newSalsa20Stream},
}
func CheckCipherMethod(method string) error {
if method == "" {
method = "aes-256-cfb"
}
_, ok := cipherMethod[method]
if !ok {
return errors.New("Unsupported encryption method: " + method)
}
return nil
}
type Cipher struct {
enc cipher.Stream
dec cipher.Stream
key []byte
info *cipherInfo
ota bool // one-time auth
iv []byte
}
// NewCipher creates a cipher that can be used in Dial() etc.
// Use cipher.Copy() to create a new cipher with the same method and password
// to avoid the cost of repeated cipher initialization.
func NewCipher(method, password string) (c *Cipher, err error) {
if password == "" {
return nil, errEmptyPassword
}
var ota bool
if strings.HasSuffix(strings.ToLower(method), "-auth") {
method = method[:len(method)-5] // len("-auth") = 5
ota = true
} else {
ota = false
}
mi, ok := cipherMethod[method]
if !ok {
return nil, errors.New("Unsupported encryption method: " + method)
}
key := evpBytesToKey(password, mi.keyLen)
c = &Cipher{key: key, info: mi}
if err != nil {
return nil, err
}
c.ota = ota
return c, nil
}
// Initializes the block cipher with CFB mode, returns IV.
func (c *Cipher) initEncrypt() (iv []byte, err error) {
if c.iv == nil {
iv = make([]byte, c.info.ivLen)
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
c.iv = iv
} else {
iv = c.iv
}
c.enc, err = c.info.newStream(c.key, iv, Encrypt)
return
}
func (c *Cipher) initDecrypt(iv []byte) (err error) {
c.dec, err = c.info.newStream(c.key, iv, Decrypt)
return
}
func (c *Cipher) encrypt(dst, src []byte) {
c.enc.XORKeyStream(dst, src)
}
func (c *Cipher) decrypt(dst, src []byte) {
c.dec.XORKeyStream(dst, src)
}
func (c *Cipher) Encrypt(src []byte) (cipherData []byte) {
cipher := c.Copy()
iv, err := cipher.initEncrypt()
if err != nil {
return
}
packetLen := len(src) + len(iv)
cipherData = make([]byte, packetLen)
copy(cipherData, iv)
cipher.encrypt(cipherData[len(iv):], src)
return
}
func (c *Cipher) Decrypt(src []byte) (data []byte) {
cipher := c.Copy()
if len(src) < c.info.ivLen {
return
}
iv := make([]byte, c.info.ivLen)
copy(iv, src[:c.info.ivLen])
if err := cipher.initDecrypt(iv); err != nil {
return
}
data = make([]byte, len(src)-len(iv))
cipher.decrypt(data[0:], src[c.info.ivLen:])
return
}
// Copy creates a new cipher at its initial state.
func (c *Cipher) Copy() *Cipher {
// This optimization may not be necessary. But without this function, we
// would need to maintain a table cache for newTableCipher and use a lock to
// protect concurrent access to that cache.
// AES and DES ciphers do not return specific types, so it's difficult
// to create a copy. But their initialization time is less than 4000ns on my
// 2.26 GHz Intel Core 2 Duo processor, so there is no need to worry.
// Currently, Blowfish and CAST5 initialization cost is an order of
// magnitude slower than that of the other ciphers. (I'm not sure whether this
// is because the current implementation is not highly optimized, or whether
// this is the nature of the algorithms.)
nc := *c
nc.enc = nil
nc.dec = nil
nc.ota = c.ota
return &nc
}

utils/ss/util.go Normal file

@@ -0,0 +1,131 @@
package ss
import (
"crypto/hmac"
"crypto/sha1"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"bitbucket.org/snail/proxy/utils"
)
const leakyBufSize = 4108 // data.len(2) + hmacsha1(10) + data(4096)
const maxNBuf = 2048
var leakyBuf = utils.NewLeakyBuf(maxNBuf, leakyBufSize)
func IsFileExists(path string) (bool, error) {
stat, err := os.Stat(path)
if err == nil {
if stat.Mode()&os.ModeType == 0 {
return true, nil
}
return false, errors.New(path + " exists but is not regular file")
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
func HmacSha1(key []byte, data []byte) []byte {
hmacSha1 := hmac.New(sha1.New, key)
hmacSha1.Write(data)
return hmacSha1.Sum(nil)[:10]
}
func otaConnectAuth(iv, key, data []byte) []byte {
return append(data, HmacSha1(append(iv, key...), data)...)
}
func otaReqChunkAuth(iv []byte, chunkId uint32, data []byte) []byte {
nb := make([]byte, 2)
binary.BigEndian.PutUint16(nb, uint16(len(data)))
chunkIdBytes := make([]byte, 4)
binary.BigEndian.PutUint32(chunkIdBytes, chunkId)
header := append(nb, HmacSha1(append(iv, chunkIdBytes...), data)...)
return append(header, data...)
}
const (
idType = 0 // address type index
idIP0 = 1 // ip address start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index
typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address
lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
lenDmBase = 2 // 1addrLen + 2port, plus addrLen
lenHmacSha1 = 10
)
func GetRequest(conn *Conn) (host string, err error) {
// buf should be at least as large as the largest possible request
// (when addrType is 3, the domain name has at most 255 bytes):
// 1(addrType) + 1(lenByte) + 255(max length address) + 2(port) + 10(hmac-sha1)
buf := make([]byte, 269)
// read till we get possible domain length field
if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
return
}
var reqStart, reqEnd int
addrType := buf[idType]
switch addrType & AddrMask {
case typeIPv4:
reqStart, reqEnd = idIP0, idIP0+lenIPv4
case typeIPv6:
reqStart, reqEnd = idIP0, idIP0+lenIPv6
case typeDm:
if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
return
}
reqStart, reqEnd = idDm0, idDm0+int(buf[idDmLen])+lenDmBase
default:
err = fmt.Errorf("addr type %d not supported", addrType&AddrMask)
return
}
if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
return
}
// Returning a string for typeIP is not the most efficient, but browsers
// (Chrome, Safari, Firefox) all seem to use typeDm exclusively, so this is
// not a big problem.
switch addrType & AddrMask {
case typeIPv4:
host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
case typeIPv6:
host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String()
case typeDm:
host = string(buf[idDm0 : idDm0+int(buf[idDmLen])])
}
// parse port
port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
host = net.JoinHostPort(host, strconv.Itoa(int(port)))
return
}
type ClosedFlag struct {
flag bool
}
func (flag *ClosedFlag) SetClosed() {
flag.flag = true
}
func (flag *ClosedFlag) IsClosed() bool {
return flag.flag
}
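For reference, a sketch of the OTA chunk framing that otaReqChunkAuth above produces, rebuilt here from the exported HmacSha1 helper; the iv, chunk id and data are placeholder values.

package main

import (
	"encoding/binary"
	"fmt"

	"bitbucket.org/snail/proxy/utils/ss"
)

func main() {
	iv := make([]byte, 16) // would come from the stream cipher
	chunkID := uint32(0)   // incremented per chunk by Conn.GetAndIncrChunkId
	data := []byte("hello")

	idBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(idBytes, chunkID)

	// Frame layout: DATA.LEN(2) | HMAC-SHA1[:10] keyed with iv||chunkID | DATA
	lenBytes := make([]byte, 2)
	binary.BigEndian.PutUint16(lenBytes, uint16(len(data)))

	frame := append(lenBytes, ss.HmacSha1(append(iv, idBytes...), data)...)
	frame = append(frame, data...)
	fmt.Printf("% x\n", frame)
}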