merge enterprise
This commit is contained in:
161
services/sps/ssudp.go
Normal file
161
services/sps/ssudp.go
Normal file
@ -0,0 +1,161 @@
|
||||
package sps
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"bitbucket.org/snail/proxy/utils"
|
||||
goaes "bitbucket.org/snail/proxy/utils/aes"
|
||||
"bitbucket.org/snail/proxy/utils/socks"
|
||||
)
|
||||
|
||||
func (s *SPS) RunSSUDP(addr string) (err error) {
|
||||
a, _ := net.ResolveUDPAddr("udp", addr)
|
||||
listener, err := net.ListenUDP("udp", a)
|
||||
if err != nil {
|
||||
s.log.Printf("ss udp bind error %s", err)
|
||||
return
|
||||
}
|
||||
s.log.Printf("ss udp on %s", listener.LocalAddr())
|
||||
s.udpRelatedPacketConns.Set(addr, listener)
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("udp local->out io copy crashed:\n%s\n%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
buf := utils.LeakyBuffer.Get()
|
||||
defer utils.LeakyBuffer.Put(buf)
|
||||
n, srcAddr, err := listener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
s.log.Printf("read from client error %s", err)
|
||||
if utils.IsNetClosedErr(err) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
var (
|
||||
inconnRemoteAddr = srcAddr.String()
|
||||
outUDPConn *net.UDPConn
|
||||
outconn net.Conn
|
||||
outconnLocalAddr string
|
||||
destAddr *net.UDPAddr
|
||||
clean = func(msg, err string) {
|
||||
raddr := ""
|
||||
if outUDPConn != nil {
|
||||
raddr = outUDPConn.RemoteAddr().String()
|
||||
outUDPConn.Close()
|
||||
}
|
||||
if msg != "" {
|
||||
if raddr != "" {
|
||||
s.log.Printf("%s , %s , %s -> %s", msg, err, inconnRemoteAddr, raddr)
|
||||
} else {
|
||||
s.log.Printf("%s , %s , from : %s", msg, err, inconnRemoteAddr)
|
||||
}
|
||||
}
|
||||
s.userConns.Remove(inconnRemoteAddr)
|
||||
if outconn != nil {
|
||||
outconn.Close()
|
||||
}
|
||||
if outconnLocalAddr != "" {
|
||||
s.userConns.Remove(outconnLocalAddr)
|
||||
}
|
||||
}
|
||||
)
|
||||
defer clean("", "")
|
||||
|
||||
raw := new(bytes.Buffer)
|
||||
raw.Write([]byte{0x00, 0x00, 0x00})
|
||||
raw.Write(s.localCipher.Decrypt(buf[:n]))
|
||||
socksPacket := socks.NewPacketUDP()
|
||||
err = socksPacket.Parse(raw.Bytes())
|
||||
raw = nil
|
||||
if err != nil {
|
||||
s.log.Printf("udp parse error %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := s.udpRelatedPacketConns.Get(inconnRemoteAddr); !ok {
|
||||
//socks client
|
||||
lbAddr := s.lb.Select(inconnRemoteAddr, *s.cfg.LoadBalanceOnlyHA)
|
||||
outconn, err := s.GetParentConn(lbAddr)
|
||||
if err != nil {
|
||||
clean("connnect fail", fmt.Sprintf("%s", err))
|
||||
return
|
||||
}
|
||||
|
||||
client, err := s.HandshakeSocksParent(&outconn, "udp", socksPacket.Addr(), socks.Auth{}, true)
|
||||
if err != nil {
|
||||
clean("handshake fail", fmt.Sprintf("%s", err))
|
||||
return
|
||||
}
|
||||
|
||||
outconnLocalAddr = outconn.LocalAddr().String()
|
||||
s.userConns.Set(outconnLocalAddr, &outconn)
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
s.log.Printf("udp related parent tcp conn read crashed:\n%s\n%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1)
|
||||
outconn.SetReadDeadline(time.Time{})
|
||||
if _, err := outconn.Read(buf); err != nil {
|
||||
clean("udp parent tcp conn disconnected", fmt.Sprintf("%s", err))
|
||||
}
|
||||
}()
|
||||
destAddr, _ = net.ResolveUDPAddr("udp", client.UDPAddr)
|
||||
localZeroAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
|
||||
outUDPConn, err = net.DialUDP("udp", localZeroAddr, destAddr)
|
||||
if err != nil {
|
||||
s.log.Printf("create out udp conn fail , %s , from : %s", err, srcAddr)
|
||||
return
|
||||
}
|
||||
s.udpRelatedPacketConns.Set(srcAddr.String(), outUDPConn)
|
||||
utils.UDPCopy(listener, outUDPConn, srcAddr, time.Second*5, func(data []byte) []byte {
|
||||
//forward to local
|
||||
var v []byte
|
||||
//convert parent data to raw
|
||||
if len(s.udpParentKey) > 0 {
|
||||
v, err = goaes.Decrypt(s.udpParentKey, data)
|
||||
if err != nil {
|
||||
s.log.Printf("udp outconn parse packet fail, %s", err.Error())
|
||||
return []byte{}
|
||||
}
|
||||
} else {
|
||||
v = data
|
||||
}
|
||||
return s.localCipher.Encrypt(v[3:])
|
||||
}, func(err interface{}) {
|
||||
s.udpRelatedPacketConns.Remove(srcAddr.String())
|
||||
if err != nil {
|
||||
s.log.Printf("udp out->local io copy crashed:\n%s\n%s", err, string(debug.Stack()))
|
||||
}
|
||||
})
|
||||
} else {
|
||||
outUDPConn = v.(*net.UDPConn)
|
||||
}
|
||||
//forward to parent
|
||||
//p is raw, now convert it to parent
|
||||
var v []byte
|
||||
if len(s.udpParentKey) > 0 {
|
||||
v, _ = goaes.Encrypt(s.udpParentKey, socksPacket.Bytes())
|
||||
} else {
|
||||
v = socksPacket.Bytes()
|
||||
}
|
||||
_, err = outUDPConn.Write(v)
|
||||
socksPacket = socks.PacketUDP{}
|
||||
if err != nil {
|
||||
if utils.IsNetClosedErr(err) {
|
||||
return
|
||||
}
|
||||
s.log.Printf("send out udp data fail , %s , from : %s", err, srcAddr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
235
utils/datasize/datasize.go
Normal file
235
utils/datasize/datasize.go
Normal file
@ -0,0 +1,235 @@
|
||||
package datasize
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ByteSize uint64
|
||||
|
||||
const (
|
||||
B ByteSize = 1
|
||||
KB = B << 10
|
||||
MB = KB << 10
|
||||
GB = MB << 10
|
||||
TB = GB << 10
|
||||
PB = TB << 10
|
||||
EB = PB << 10
|
||||
|
||||
fnUnmarshalText string = "UnmarshalText"
|
||||
maxUint64 uint64 = (1 << 64) - 1
|
||||
cutoff uint64 = maxUint64 / 10
|
||||
)
|
||||
|
||||
var ErrBits = errors.New("unit with capital unit prefix and lower case unit (b) - bits, not bytes ")
|
||||
|
||||
var defaultDatasize ByteSize
|
||||
|
||||
func Parse(s string) (bytes uint64, err error) {
|
||||
err = defaultDatasize.UnmarshalText([]byte(s))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
bytes = defaultDatasize.Bytes()
|
||||
return
|
||||
}
|
||||
func HumanSize(bytes uint64) (s string, err error) {
|
||||
err = defaultDatasize.UnmarshalText([]byte(fmt.Sprintf("%d", bytes)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s = defaultDatasize.HumanReadable()
|
||||
return
|
||||
}
|
||||
func (b ByteSize) Bytes() uint64 {
|
||||
return uint64(b)
|
||||
}
|
||||
|
||||
func (b ByteSize) KBytes() float64 {
|
||||
v := b / KB
|
||||
r := b % KB
|
||||
return float64(v) + float64(r)/float64(KB)
|
||||
}
|
||||
|
||||
func (b ByteSize) MBytes() float64 {
|
||||
v := b / MB
|
||||
r := b % MB
|
||||
return float64(v) + float64(r)/float64(MB)
|
||||
}
|
||||
|
||||
func (b ByteSize) GBytes() float64 {
|
||||
v := b / GB
|
||||
r := b % GB
|
||||
return float64(v) + float64(r)/float64(GB)
|
||||
}
|
||||
|
||||
func (b ByteSize) TBytes() float64 {
|
||||
v := b / TB
|
||||
r := b % TB
|
||||
return float64(v) + float64(r)/float64(TB)
|
||||
}
|
||||
|
||||
func (b ByteSize) PBytes() float64 {
|
||||
v := b / PB
|
||||
r := b % PB
|
||||
return float64(v) + float64(r)/float64(PB)
|
||||
}
|
||||
|
||||
func (b ByteSize) EBytes() float64 {
|
||||
v := b / EB
|
||||
r := b % EB
|
||||
return float64(v) + float64(r)/float64(EB)
|
||||
}
|
||||
|
||||
func (b ByteSize) String() string {
|
||||
switch {
|
||||
case b == 0:
|
||||
return fmt.Sprint("0B")
|
||||
case b%EB == 0:
|
||||
return fmt.Sprintf("%dEB", b/EB)
|
||||
case b%PB == 0:
|
||||
return fmt.Sprintf("%dPB", b/PB)
|
||||
case b%TB == 0:
|
||||
return fmt.Sprintf("%dTB", b/TB)
|
||||
case b%GB == 0:
|
||||
return fmt.Sprintf("%dGB", b/GB)
|
||||
case b%MB == 0:
|
||||
return fmt.Sprintf("%dMB", b/MB)
|
||||
case b%KB == 0:
|
||||
return fmt.Sprintf("%dKB", b/KB)
|
||||
default:
|
||||
return fmt.Sprintf("%dB", b)
|
||||
}
|
||||
}
|
||||
|
||||
func (b ByteSize) HR() string {
|
||||
return b.HumanReadable()
|
||||
}
|
||||
|
||||
func (b ByteSize) HumanReadable() string {
|
||||
switch {
|
||||
case b > EB:
|
||||
return fmt.Sprintf("%.1f EB", b.EBytes())
|
||||
case b > PB:
|
||||
return fmt.Sprintf("%.1f PB", b.PBytes())
|
||||
case b > TB:
|
||||
return fmt.Sprintf("%.1f TB", b.TBytes())
|
||||
case b > GB:
|
||||
return fmt.Sprintf("%.1f GB", b.GBytes())
|
||||
case b > MB:
|
||||
return fmt.Sprintf("%.1f MB", b.MBytes())
|
||||
case b > KB:
|
||||
return fmt.Sprintf("%.1f KB", b.KBytes())
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
}
|
||||
|
||||
func (b ByteSize) MarshalText() ([]byte, error) {
|
||||
return []byte(b.String()), nil
|
||||
}
|
||||
|
||||
func (b *ByteSize) UnmarshalText(t []byte) error {
|
||||
var val uint64
|
||||
var unit string
|
||||
|
||||
// copy for error message
|
||||
t0 := t
|
||||
|
||||
var c byte
|
||||
var i int
|
||||
|
||||
ParseLoop:
|
||||
for i < len(t) {
|
||||
c = t[i]
|
||||
switch {
|
||||
case '0' <= c && c <= '9':
|
||||
if val > cutoff {
|
||||
goto Overflow
|
||||
}
|
||||
|
||||
c = c - '0'
|
||||
val *= 10
|
||||
|
||||
if val > val+uint64(c) {
|
||||
// val+v overflows
|
||||
goto Overflow
|
||||
}
|
||||
val += uint64(c)
|
||||
i++
|
||||
|
||||
default:
|
||||
if i == 0 {
|
||||
goto SyntaxError
|
||||
}
|
||||
break ParseLoop
|
||||
}
|
||||
}
|
||||
|
||||
unit = strings.TrimSpace(string(t[i:]))
|
||||
switch unit {
|
||||
case "Kb", "Mb", "Gb", "Tb", "Pb", "Eb":
|
||||
goto BitsError
|
||||
}
|
||||
unit = strings.ToLower(unit)
|
||||
switch unit {
|
||||
case "", "b", "byte":
|
||||
// do nothing - already in bytes
|
||||
|
||||
case "k", "kb", "kilo", "kilobyte", "kilobytes":
|
||||
if val > maxUint64/uint64(KB) {
|
||||
goto Overflow
|
||||
}
|
||||
val *= uint64(KB)
|
||||
|
||||
case "m", "mb", "mega", "megabyte", "megabytes":
|
||||
if val > maxUint64/uint64(MB) {
|
||||
goto Overflow
|
||||
}
|
||||
val *= uint64(MB)
|
||||
|
||||
case "g", "gb", "giga", "gigabyte", "gigabytes":
|
||||
if val > maxUint64/uint64(GB) {
|
||||
goto Overflow
|
||||
}
|
||||
val *= uint64(GB)
|
||||
|
||||
case "t", "tb", "tera", "terabyte", "terabytes":
|
||||
if val > maxUint64/uint64(TB) {
|
||||
goto Overflow
|
||||
}
|
||||
val *= uint64(TB)
|
||||
|
||||
case "p", "pb", "peta", "petabyte", "petabytes":
|
||||
if val > maxUint64/uint64(PB) {
|
||||
goto Overflow
|
||||
}
|
||||
val *= uint64(PB)
|
||||
|
||||
case "E", "EB", "e", "eb", "eB":
|
||||
if val > maxUint64/uint64(EB) {
|
||||
goto Overflow
|
||||
}
|
||||
val *= uint64(EB)
|
||||
|
||||
default:
|
||||
goto SyntaxError
|
||||
}
|
||||
|
||||
*b = ByteSize(val)
|
||||
return nil
|
||||
|
||||
Overflow:
|
||||
*b = ByteSize(maxUint64)
|
||||
return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrRange}
|
||||
|
||||
SyntaxError:
|
||||
*b = 0
|
||||
return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrSyntax}
|
||||
|
||||
BitsError:
|
||||
*b = 0
|
||||
return &strconv.NumError{fnUnmarshalText, string(t0), ErrBits}
|
||||
}
|
||||
114
utils/dnsx/resolver.go
Normal file
114
utils/dnsx/resolver.go
Normal file
@ -0,0 +1,114 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
logger "log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"bitbucket.org/snail/proxy/utils/mapx"
|
||||
dns "github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type DomainResolver struct {
|
||||
ttl int
|
||||
dnsAddrress string
|
||||
data mapx.ConcurrentMap
|
||||
log *logger.Logger
|
||||
}
|
||||
type DomainResolverItem struct {
|
||||
ip string
|
||||
domain string
|
||||
expiredAt int64
|
||||
}
|
||||
|
||||
func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver {
|
||||
return DomainResolver{
|
||||
ttl: ttl,
|
||||
dnsAddrress: dnsAddrress,
|
||||
data: mapx.NewConcurrentMap(),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
func (a *DomainResolver) DnsAddress() (address string) {
|
||||
address = a.dnsAddrress
|
||||
return
|
||||
}
|
||||
func (a *DomainResolver) MustResolve(address string) (ip string) {
|
||||
ip, _ = a.Resolve(address)
|
||||
return
|
||||
}
|
||||
func (a *DomainResolver) Resolve(address string) (ip string, err error) {
|
||||
domain := address
|
||||
port := ""
|
||||
fromCache := "false"
|
||||
defer func() {
|
||||
if port != "" {
|
||||
ip = net.JoinHostPort(ip, port)
|
||||
}
|
||||
a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache)
|
||||
//a.PrintData()
|
||||
}()
|
||||
if strings.Contains(domain, ":") {
|
||||
domain, port, err = net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if net.ParseIP(domain) != nil {
|
||||
ip = domain
|
||||
fromCache = "ip ignore"
|
||||
return
|
||||
}
|
||||
item, ok := a.data.Get(domain)
|
||||
if ok {
|
||||
//log.Println("find ", domain)
|
||||
if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
|
||||
ip = (*item.(*DomainResolverItem)).ip
|
||||
fromCache = "true"
|
||||
//log.Println("from cache ", domain)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
item = &DomainResolverItem{
|
||||
domain: domain,
|
||||
}
|
||||
|
||||
}
|
||||
c := new(dns.Client)
|
||||
c.DialTimeout = time.Millisecond * 5000
|
||||
c.ReadTimeout = time.Millisecond * 5000
|
||||
c.WriteTimeout = time.Millisecond * 5000
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
r, _, err := c.Exchange(m, a.dnsAddrress)
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
|
||||
return
|
||||
}
|
||||
for _, answer := range r.Answer {
|
||||
if answer.Header().Rrtype == dns.TypeA {
|
||||
info := strings.Fields(answer.String())
|
||||
if len(info) >= 5 {
|
||||
ip = info[4]
|
||||
_item := item.(*DomainResolverItem)
|
||||
(*_item).expiredAt = time.Now().Unix() + int64(a.ttl)
|
||||
(*_item).ip = ip
|
||||
a.data.Set(domain, item)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (a *DomainResolver) PrintData() {
|
||||
for k, item := range a.data.Items() {
|
||||
d := item.(*DomainResolverItem)
|
||||
a.log.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt)
|
||||
}
|
||||
}
|
||||
178
utils/iolimiter/iolimiter.go
Normal file
178
utils/iolimiter/iolimiter.go
Normal file
@ -0,0 +1,178 @@
|
||||
package iolimiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const burstLimit = 1000 * 1000 * 1000
|
||||
|
||||
type Reader struct {
|
||||
r io.Reader
|
||||
limiter *rate.Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
w io.Writer
|
||||
limiter *rate.Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
net.Conn
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
readLimiter *rate.Limiter
|
||||
writeLimiter *rate.Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
//NewtRateLimitConn sets rate limit (bytes/sec) to the Conn read and write.
|
||||
func NewtConn(c net.Conn, bytesPerSec float64) net.Conn {
|
||||
s := &conn{
|
||||
Conn: c,
|
||||
r: c,
|
||||
w: c,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
return s
|
||||
}
|
||||
|
||||
//NewtRateLimitReaderConn sets rate limit (bytes/sec) to the Conn read.
|
||||
func NewReaderConn(c net.Conn, bytesPerSec float64) net.Conn {
|
||||
s := &conn{
|
||||
Conn: c,
|
||||
r: c,
|
||||
w: c,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
return s
|
||||
}
|
||||
|
||||
//NewtRateLimitWriterConn sets rate limit (bytes/sec) to the Conn write.
|
||||
func NewWriterConn(c net.Conn, bytesPerSec float64) net.Conn {
|
||||
s := &conn{
|
||||
Conn: c,
|
||||
r: c,
|
||||
w: c,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
return s
|
||||
}
|
||||
|
||||
// Read reads bytes into p.
|
||||
func (s *conn) Read(p []byte) (int, error) {
|
||||
if s.readLimiter == nil {
|
||||
return s.r.Read(p)
|
||||
}
|
||||
n, err := s.r.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if err := s.readLimiter.WaitN(s.ctx, n); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write writes bytes from p.
|
||||
func (s *conn) Write(p []byte) (int, error) {
|
||||
if s.writeLimiter == nil {
|
||||
return s.w.Write(p)
|
||||
}
|
||||
n, err := s.w.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if err := s.writeLimiter.WaitN(s.ctx, n); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// NewReader returns a reader that implements io.Reader with rate limiting.
|
||||
func NewReader(r io.Reader) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewReaderWithContext returns a reader that implements io.Reader with rate limiting.
|
||||
func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriter returns a writer that implements io.Writer with rate limiting.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriterWithContext returns a writer that implements io.Writer with rate limiting.
|
||||
func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// SetRateLimit sets rate limit (bytes/sec) to the reader.
|
||||
func (s *Reader) SetRateLimit(bytesPerSec float64) {
|
||||
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
}
|
||||
|
||||
// Read reads bytes into p.
|
||||
func (s *Reader) Read(p []byte) (int, error) {
|
||||
if s.limiter == nil {
|
||||
return s.r.Read(p)
|
||||
}
|
||||
n, err := s.r.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// SetRateLimit sets rate limit (bytes/sec) to the writer.
|
||||
func (s *Writer) SetRateLimit(bytesPerSec float64) {
|
||||
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
}
|
||||
|
||||
// Write writes bytes from p.
|
||||
func (s *Writer) Write(p []byte) (int, error) {
|
||||
if s.limiter == nil {
|
||||
return s.w.Write(p)
|
||||
}
|
||||
n, err := s.w.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
203
utils/lb/backend.go
Normal file
203
utils/lb/backend.go
Normal file
@ -0,0 +1,203 @@
|
||||
package lb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"bitbucket.org/snail/proxy/utils/dnsx"
|
||||
)
|
||||
|
||||
// BackendConfig it's the configuration loaded
|
||||
type BackendConfig struct {
|
||||
Address string
|
||||
|
||||
ActiveAfter int
|
||||
InactiveAfter int
|
||||
Weight int
|
||||
|
||||
Timeout time.Duration
|
||||
RetryTime time.Duration
|
||||
|
||||
IsMuxCheck bool
|
||||
ConnFactory func(address string, timeout time.Duration) (net.Conn, error)
|
||||
}
|
||||
type BackendsConfig []*BackendConfig
|
||||
|
||||
// BackendControl keep the control data
|
||||
type BackendControl struct {
|
||||
Failed bool // The last request failed
|
||||
Active bool
|
||||
|
||||
InactiveTries int
|
||||
ActiveTries int
|
||||
|
||||
Connections int
|
||||
|
||||
ConnectUsedMillisecond int
|
||||
|
||||
isStop bool
|
||||
}
|
||||
|
||||
// Backend structure
|
||||
type Backend struct {
|
||||
BackendConfig
|
||||
BackendControl
|
||||
sync.RWMutex
|
||||
log *log.Logger
|
||||
dr *dnsx.DomainResolver
|
||||
}
|
||||
|
||||
type Backends []*Backend
|
||||
|
||||
func NewBackend(backendConfig BackendConfig, dr *dnsx.DomainResolver, log *log.Logger) (*Backend, error) {
|
||||
|
||||
if backendConfig.Address == "" {
|
||||
return nil, errors.New("Address rquired")
|
||||
}
|
||||
if backendConfig.ActiveAfter == 0 {
|
||||
backendConfig.ActiveAfter = 2
|
||||
}
|
||||
if backendConfig.InactiveAfter == 0 {
|
||||
backendConfig.InactiveAfter = 3
|
||||
}
|
||||
if backendConfig.Weight == 0 {
|
||||
backendConfig.Weight = 1
|
||||
}
|
||||
if backendConfig.Timeout == 0 {
|
||||
backendConfig.Timeout = time.Millisecond * 1500
|
||||
}
|
||||
if backendConfig.RetryTime == 0 {
|
||||
backendConfig.RetryTime = time.Millisecond * 2000
|
||||
}
|
||||
return &Backend{
|
||||
dr: dr,
|
||||
log: log,
|
||||
BackendConfig: backendConfig,
|
||||
BackendControl: BackendControl{
|
||||
Failed: true,
|
||||
Active: false,
|
||||
InactiveTries: 0,
|
||||
ActiveTries: 0,
|
||||
Connections: 0,
|
||||
ConnectUsedMillisecond: 0,
|
||||
isStop: false,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
func (b *Backend) StopHeartCheck() {
|
||||
b.isStop = true
|
||||
}
|
||||
|
||||
func (b *Backend) IncreasConns() {
|
||||
b.RWMutex.Lock()
|
||||
b.Connections++
|
||||
b.RWMutex.Unlock()
|
||||
}
|
||||
|
||||
func (b *Backend) DecreaseConns() {
|
||||
b.RWMutex.Lock()
|
||||
b.Connections--
|
||||
b.RWMutex.Unlock()
|
||||
}
|
||||
|
||||
func (b *Backend) StartHeartCheck() {
|
||||
if b.IsMuxCheck {
|
||||
b.startMuxHeartCheck()
|
||||
} else {
|
||||
b.startTCPHeartCheck()
|
||||
}
|
||||
}
|
||||
func (b *Backend) startMuxHeartCheck() {
|
||||
go func() {
|
||||
for {
|
||||
if b.isStop {
|
||||
return
|
||||
}
|
||||
var c net.Conn
|
||||
var err error
|
||||
start := time.Now().UnixNano() / int64(time.Microsecond)
|
||||
c, err = b.getConn()
|
||||
b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start)
|
||||
if err != nil {
|
||||
b.Active = false
|
||||
time.Sleep(time.Second * 2)
|
||||
continue
|
||||
} else {
|
||||
b.Active = true
|
||||
}
|
||||
for {
|
||||
buf := make([]byte, 1)
|
||||
c.Read(buf)
|
||||
buf = nil
|
||||
break
|
||||
}
|
||||
b.Active = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Monitoring the backend
|
||||
func (b *Backend) startTCPHeartCheck() {
|
||||
go func() {
|
||||
for {
|
||||
if b.isStop {
|
||||
return
|
||||
}
|
||||
var c net.Conn
|
||||
var err error
|
||||
start := time.Now().UnixNano() / int64(time.Microsecond)
|
||||
c, err = b.getConn()
|
||||
b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start)
|
||||
if err == nil {
|
||||
c.Close()
|
||||
}
|
||||
if err != nil {
|
||||
b.RWMutex.Lock()
|
||||
// Max tries before consider inactive
|
||||
if b.InactiveTries >= b.InactiveAfter {
|
||||
//b.log.Printf("Backend inactive [%s]", b.Address)
|
||||
b.Active = false
|
||||
b.ActiveTries = 0
|
||||
} else {
|
||||
// Ok that guy it's out of the game
|
||||
b.Failed = true
|
||||
b.InactiveTries++
|
||||
//b.log.Printf("Error to check address [%s] tries [%d]", b.Address, b.InactiveTries)
|
||||
}
|
||||
b.RWMutex.Unlock()
|
||||
} else {
|
||||
|
||||
// Ok, let's keep working boys
|
||||
b.RWMutex.Lock()
|
||||
if b.ActiveTries >= b.ActiveAfter {
|
||||
if b.Failed {
|
||||
//log.Printf("Backend active [%s]", b.Address)
|
||||
}
|
||||
b.Failed = false
|
||||
b.Active = true
|
||||
b.InactiveTries = 0
|
||||
} else {
|
||||
b.ActiveTries++
|
||||
}
|
||||
b.RWMutex.Unlock()
|
||||
}
|
||||
time.Sleep(b.RetryTime)
|
||||
}
|
||||
}()
|
||||
}
|
||||
func (b *Backend) getConn() (conn net.Conn, err error) {
|
||||
address := b.Address
|
||||
if b.dr != nil && b.dr.DnsAddress() != "" {
|
||||
address, err = b.dr.Resolve(b.Address)
|
||||
if err != nil {
|
||||
b.log.Printf("dns error %s , ERR:%s", b.Address, err)
|
||||
}
|
||||
}
|
||||
if b.ConnFactory != nil {
|
||||
return b.ConnFactory(address, b.Timeout)
|
||||
}
|
||||
return net.DialTimeout("tcp", address, b.Timeout)
|
||||
}
|
||||
687
utils/lb/lb.go
Normal file
687
utils/lb/lb.go
Normal file
@ -0,0 +1,687 @@
|
||||
package lb
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"bitbucket.org/snail/proxy/utils/dnsx"
|
||||
)
|
||||
|
||||
const (
|
||||
SELECT_ROUNDROBIN = iota
|
||||
SELECT_LEASTCONN
|
||||
SELECT_HASH
|
||||
SELECT_WEITHT
|
||||
SELECT_LEASTTIME
|
||||
)
|
||||
|
||||
type Selector interface {
|
||||
Select(srcAddr string) (addr string)
|
||||
SelectBackend(srcAddr string) (b *Backend)
|
||||
IncreasConns(addr string)
|
||||
DecreaseConns(addr string)
|
||||
Stop()
|
||||
Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger)
|
||||
IsActive() bool
|
||||
ActiveCount() (count int)
|
||||
Backends() (bs []*Backend)
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
selector *Selector
|
||||
log *log.Logger
|
||||
dr *dnsx.DomainResolver
|
||||
lock *sync.Mutex
|
||||
last *Backend
|
||||
debug bool
|
||||
}
|
||||
|
||||
func NewGroup(selectType int, configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger, debug bool) Group {
|
||||
bks := []*Backend{}
|
||||
for _, c := range configs {
|
||||
b, _ := NewBackend(*c, dr, log)
|
||||
bks = append(bks, b)
|
||||
}
|
||||
if len(bks) > 1 {
|
||||
for _, b := range bks {
|
||||
b.StartHeartCheck()
|
||||
}
|
||||
}
|
||||
var s Selector
|
||||
switch selectType {
|
||||
case SELECT_ROUNDROBIN:
|
||||
s = NewRoundRobin(bks, log, debug)
|
||||
case SELECT_LEASTCONN:
|
||||
s = NewLeastConn(bks, log, debug)
|
||||
case SELECT_HASH:
|
||||
s = NewHash(bks, log, debug)
|
||||
case SELECT_WEITHT:
|
||||
s = NewWeight(bks, log, debug)
|
||||
case SELECT_LEASTTIME:
|
||||
s = NewLeastTime(bks, log, debug)
|
||||
}
|
||||
return Group{
|
||||
selector: &s,
|
||||
log: log,
|
||||
dr: dr,
|
||||
lock: &sync.Mutex{},
|
||||
debug: debug,
|
||||
}
|
||||
}
|
||||
func (g *Group) Select(srcAddr string, onlyHa bool) (addr string) {
|
||||
if onlyHa {
|
||||
g.lock.Lock()
|
||||
defer g.lock.Unlock()
|
||||
if g.last != nil && (g.last.Active || g.last.ConnectUsedMillisecond == 0) {
|
||||
if g.debug {
|
||||
g.log.Printf("############ choosed %s from lastest ############", g.last.Address)
|
||||
printDebug(true, g.log, nil, srcAddr, (*g.selector).Backends())
|
||||
}
|
||||
return g.last.Address
|
||||
}
|
||||
g.last = (*g.selector).SelectBackend(srcAddr)
|
||||
if !g.last.Active && g.last.ConnectUsedMillisecond > 0 {
|
||||
g.log.Printf("###warn### lb selected empty , return default , for : %s", srcAddr)
|
||||
}
|
||||
return g.last.Address
|
||||
}
|
||||
b := (*g.selector).SelectBackend(srcAddr)
|
||||
return b.Address
|
||||
|
||||
}
|
||||
func (g *Group) IncreasConns(addr string) {
|
||||
(*g.selector).IncreasConns(addr)
|
||||
}
|
||||
func (g *Group) DecreaseConns(addr string) {
|
||||
(*g.selector).DecreaseConns(addr)
|
||||
}
|
||||
func (g *Group) Stop() {
|
||||
if g.selector != nil {
|
||||
(*g.selector).Stop()
|
||||
}
|
||||
}
|
||||
func (g *Group) IsActive() bool {
|
||||
return (*g.selector).IsActive()
|
||||
}
|
||||
func (g *Group) ActiveCount() (count int) {
|
||||
return (*g.selector).ActiveCount()
|
||||
}
|
||||
func (g *Group) Reset(addrs []string) {
|
||||
bks := (*g.selector).Backends()
|
||||
if len(bks) == 0 {
|
||||
return
|
||||
}
|
||||
cfg := bks[0].BackendConfig
|
||||
configs := BackendsConfig{}
|
||||
for _, addr := range addrs {
|
||||
c := cfg
|
||||
c.Address = addr
|
||||
configs = append(configs, &c)
|
||||
}
|
||||
(*g.selector).Reset(configs, g.dr, g.log)
|
||||
}
|
||||
func (g *Group) Backends() []*Backend {
|
||||
return (*g.selector).Backends()
|
||||
}
|
||||
|
||||
//########################RoundRobin##########################
|
||||
type RoundRobin struct {
|
||||
sync.Mutex
|
||||
backendIndex int
|
||||
backends Backends
|
||||
log *log.Logger
|
||||
debug bool
|
||||
}
|
||||
|
||||
func NewRoundRobin(backends Backends, log *log.Logger, debug bool) Selector {
|
||||
return &RoundRobin{
|
||||
backends: backends,
|
||||
log: log,
|
||||
debug: debug,
|
||||
}
|
||||
|
||||
}
|
||||
func (r *RoundRobin) Select(srcAddr string) (addr string) {
|
||||
return r.SelectBackend(srcAddr).Address
|
||||
}
|
||||
func (r *RoundRobin) SelectBackend(srcAddr string) (b *Backend) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
defer func() {
|
||||
printDebug(r.debug, r.log, b, srcAddr, r.backends)
|
||||
}()
|
||||
if len(r.backends) == 0 {
|
||||
return
|
||||
}
|
||||
if len(r.backends) == 1 {
|
||||
return r.backends[0]
|
||||
}
|
||||
RETRY:
|
||||
found := false
|
||||
for _, b := range r.backends {
|
||||
if b.Active {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return r.backends[0]
|
||||
}
|
||||
r.backendIndex++
|
||||
if r.backendIndex > len(r.backends)-1 {
|
||||
r.backendIndex = 0
|
||||
}
|
||||
if !r.backends[r.backendIndex].Active {
|
||||
goto RETRY
|
||||
}
|
||||
return r.backends[r.backendIndex]
|
||||
}
|
||||
func (r *RoundRobin) IncreasConns(addr string) {
|
||||
|
||||
}
|
||||
func (r *RoundRobin) DecreaseConns(addr string) {
|
||||
|
||||
}
|
||||
func (r *RoundRobin) Stop() {
|
||||
for _, b := range r.backends {
|
||||
b.StopHeartCheck()
|
||||
}
|
||||
}
|
||||
func (r *RoundRobin) Backends() []*Backend {
|
||||
return r.backends
|
||||
}
|
||||
func (r *RoundRobin) IsActive() bool {
|
||||
for _, b := range r.backends {
|
||||
if b.Active {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (r *RoundRobin) ActiveCount() (count int) {
|
||||
for _, b := range r.backends {
|
||||
if b.Active {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (r *RoundRobin) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.Stop()
|
||||
bks := []*Backend{}
|
||||
for _, c := range configs {
|
||||
b, _ := NewBackend(*c, dr, log)
|
||||
bks = append(bks, b)
|
||||
}
|
||||
if len(bks) > 1 {
|
||||
for _, b := range bks {
|
||||
b.StartHeartCheck()
|
||||
}
|
||||
}
|
||||
r.backends = bks
|
||||
}
|
||||
|
||||
//########################LeastConn##########################
|
||||
|
||||
type LeastConn struct {
|
||||
sync.Mutex
|
||||
backends Backends
|
||||
log *log.Logger
|
||||
debug bool
|
||||
}
|
||||
|
||||
func NewLeastConn(backends []*Backend, log *log.Logger, debug bool) Selector {
|
||||
lc := LeastConn{
|
||||
backends: backends,
|
||||
log: log,
|
||||
debug: debug,
|
||||
}
|
||||
return &lc
|
||||
}
|
||||
|
||||
func (lc *LeastConn) Select(srcAddr string) (addr string) {
|
||||
return lc.SelectBackend(srcAddr).Address
|
||||
}
|
||||
func (lc *LeastConn) SelectBackend(srcAddr string) (b *Backend) {
|
||||
lc.Lock()
|
||||
defer lc.Unlock()
|
||||
defer func() {
|
||||
printDebug(lc.debug, lc.log, b, srcAddr, lc.backends)
|
||||
}()
|
||||
if len(lc.backends) == 0 {
|
||||
return
|
||||
}
|
||||
if len(lc.backends) == 1 {
|
||||
return lc.backends[0]
|
||||
}
|
||||
found := false
|
||||
for _, b := range lc.backends {
|
||||
if b.Active {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return lc.backends[0]
|
||||
}
|
||||
min := lc.backends[0].Connections
|
||||
index := 0
|
||||
for i, b := range lc.backends {
|
||||
if b.Active {
|
||||
min = b.Connections
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
for i, b := range lc.backends {
|
||||
if b.Active && b.Connections <= min {
|
||||
min = b.Connections
|
||||
index = i
|
||||
}
|
||||
}
|
||||
return lc.backends[index]
|
||||
}
|
||||
func (lc *LeastConn) IncreasConns(addr string) {
|
||||
for _, a := range lc.backends {
|
||||
if a.Address == addr {
|
||||
a.IncreasConns()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
func (lc *LeastConn) DecreaseConns(addr string) {
|
||||
for _, a := range lc.backends {
|
||||
if a.Address == addr {
|
||||
a.DecreaseConns()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
func (lc *LeastConn) Stop() {
|
||||
for _, b := range lc.backends {
|
||||
b.StopHeartCheck()
|
||||
}
|
||||
}
|
||||
func (lc *LeastConn) IsActive() bool {
|
||||
for _, b := range lc.backends {
|
||||
if b.Active {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (lc *LeastConn) ActiveCount() (count int) {
|
||||
for _, b := range lc.backends {
|
||||
if b.Active {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (lc *LeastConn) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
|
||||
lc.Lock()
|
||||
defer lc.Unlock()
|
||||
lc.Stop()
|
||||
bks := []*Backend{}
|
||||
for _, c := range configs {
|
||||
b, _ := NewBackend(*c, dr, log)
|
||||
bks = append(bks, b)
|
||||
}
|
||||
if len(bks) > 1 {
|
||||
for _, b := range bks {
|
||||
b.StartHeartCheck()
|
||||
}
|
||||
}
|
||||
lc.backends = bks
|
||||
}
|
||||
func (lc *LeastConn) Backends() []*Backend {
|
||||
return lc.backends
|
||||
}
|
||||
|
||||
//########################Hash##########################
|
||||
type Hash struct {
|
||||
sync.Mutex
|
||||
backends Backends
|
||||
log *log.Logger
|
||||
debug bool
|
||||
}
|
||||
|
||||
func NewHash(backends Backends, log *log.Logger, debug bool) Selector {
|
||||
return &Hash{
|
||||
backends: backends,
|
||||
log: log,
|
||||
debug: debug,
|
||||
}
|
||||
}
|
||||
func (h *Hash) Select(srcAddr string) (addr string) {
|
||||
return h.SelectBackend(srcAddr).Address
|
||||
}
|
||||
func (h *Hash) SelectBackend(srcAddr string) (b *Backend) {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
defer func() {
|
||||
printDebug(h.debug, h.log, b, srcAddr, h.backends)
|
||||
}()
|
||||
if len(h.backends) == 0 {
|
||||
return
|
||||
}
|
||||
if len(h.backends) == 1 {
|
||||
return h.backends[0]
|
||||
}
|
||||
i := 0
|
||||
host, _, err := net.SplitHostPort(srcAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
//porti, _ := strconv.Atoi(port)
|
||||
//i += porti
|
||||
for _, b := range md5.Sum([]byte(host)) {
|
||||
i += int(b)
|
||||
}
|
||||
RETRY:
|
||||
found := false
|
||||
for _, b := range h.backends {
|
||||
if b.Active {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return h.backends[0]
|
||||
}
|
||||
k := i % len(h.backends)
|
||||
if !h.backends[k].Active {
|
||||
i++
|
||||
goto RETRY
|
||||
}
|
||||
return h.backends[k]
|
||||
}
|
||||
func (h *Hash) IncreasConns(addr string) {
|
||||
|
||||
}
|
||||
func (h *Hash) DecreaseConns(addr string) {
|
||||
|
||||
}
|
||||
func (h *Hash) Stop() {
|
||||
for _, b := range h.backends {
|
||||
b.StopHeartCheck()
|
||||
}
|
||||
}
|
||||
func (h *Hash) IsActive() bool {
|
||||
for _, b := range h.backends {
|
||||
if b.Active {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (h *Hash) ActiveCount() (count int) {
|
||||
for _, b := range h.backends {
|
||||
if b.Active {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (h *Hash) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
h.Stop()
|
||||
bks := []*Backend{}
|
||||
for _, c := range configs {
|
||||
b, _ := NewBackend(*c, dr, log)
|
||||
bks = append(bks, b)
|
||||
}
|
||||
if len(bks) > 1 {
|
||||
for _, b := range bks {
|
||||
b.StartHeartCheck()
|
||||
}
|
||||
}
|
||||
h.backends = bks
|
||||
}
|
||||
func (h *Hash) Backends() []*Backend {
|
||||
return h.backends
|
||||
}
|
||||
|
||||
//########################Weight##########################
|
||||
type Weight struct {
|
||||
sync.Mutex
|
||||
backends Backends
|
||||
log *log.Logger
|
||||
debug bool
|
||||
}
|
||||
|
||||
func NewWeight(backends Backends, log *log.Logger, debug bool) Selector {
|
||||
return &Weight{
|
||||
backends: backends,
|
||||
log: log,
|
||||
debug: debug,
|
||||
}
|
||||
}
|
||||
func (w *Weight) Select(srcAddr string) (addr string) {
|
||||
return w.SelectBackend(srcAddr).Address
|
||||
}
|
||||
func (w *Weight) SelectBackend(srcAddr string) (b *Backend) {
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
defer func() {
|
||||
printDebug(w.debug, w.log, b, srcAddr, w.backends)
|
||||
}()
|
||||
if len(w.backends) == 0 {
|
||||
return
|
||||
}
|
||||
if len(w.backends) == 1 {
|
||||
return w.backends[0]
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, b := range w.backends {
|
||||
if b.Active {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return w.backends[0]
|
||||
}
|
||||
|
||||
min := w.backends[0].Connections / w.backends[0].Weight
|
||||
index := 0
|
||||
for i, b := range w.backends {
|
||||
if b.Active {
|
||||
min = b.Connections / b.Weight
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
for i, b := range w.backends {
|
||||
if b.Active && b.Connections/b.Weight <= min {
|
||||
min = b.Connections
|
||||
index = i
|
||||
}
|
||||
}
|
||||
return w.backends[index]
|
||||
}
|
||||
func (w *Weight) IncreasConns(addr string) {
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
for _, a := range w.backends {
|
||||
if a.Address == addr {
|
||||
a.IncreasConns()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
func (w *Weight) DecreaseConns(addr string) {
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
for _, a := range w.backends {
|
||||
if a.Address == addr {
|
||||
a.DecreaseConns()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
func (w *Weight) Stop() {
|
||||
for _, b := range w.backends {
|
||||
b.StopHeartCheck()
|
||||
}
|
||||
}
|
||||
func (w *Weight) IsActive() bool {
|
||||
for _, b := range w.backends {
|
||||
if b.Active {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (w *Weight) ActiveCount() (count int) {
|
||||
for _, b := range w.backends {
|
||||
if b.Active {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (w *Weight) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
w.Stop()
|
||||
bks := []*Backend{}
|
||||
for _, c := range configs {
|
||||
b, _ := NewBackend(*c, dr, log)
|
||||
bks = append(bks, b)
|
||||
}
|
||||
if len(bks) > 1 {
|
||||
for _, b := range bks {
|
||||
b.StartHeartCheck()
|
||||
}
|
||||
}
|
||||
w.backends = bks
|
||||
}
|
||||
func (w *Weight) Backends() []*Backend {
|
||||
return w.backends
|
||||
}
|
||||
|
||||
//########################LeastTime##########################
|
||||
|
||||
type LeastTime struct {
|
||||
sync.Mutex
|
||||
backends Backends
|
||||
log *log.Logger
|
||||
debug bool
|
||||
}
|
||||
|
||||
func NewLeastTime(backends []*Backend, log *log.Logger, debug bool) Selector {
|
||||
lt := LeastTime{
|
||||
backends: backends,
|
||||
log: log,
|
||||
debug: debug,
|
||||
}
|
||||
return <
|
||||
}
|
||||
|
||||
func (lt *LeastTime) Select(srcAddr string) (addr string) {
|
||||
return lt.SelectBackend(srcAddr).Address
|
||||
}
|
||||
func (lt *LeastTime) SelectBackend(srcAddr string) (b *Backend) {
|
||||
lt.Lock()
|
||||
defer lt.Unlock()
|
||||
defer func() {
|
||||
printDebug(lt.debug, lt.log, b, srcAddr, lt.backends)
|
||||
}()
|
||||
if len(lt.backends) == 0 {
|
||||
return
|
||||
}
|
||||
if len(lt.backends) == 1 {
|
||||
return lt.backends[0]
|
||||
}
|
||||
found := false
|
||||
for _, b := range lt.backends {
|
||||
if b.Active {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return lt.backends[0]
|
||||
}
|
||||
min := lt.backends[0].ConnectUsedMillisecond
|
||||
index := 0
|
||||
for i, b := range lt.backends {
|
||||
if b.Active {
|
||||
min = b.ConnectUsedMillisecond
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
for i, b := range lt.backends {
|
||||
if b.Active && b.ConnectUsedMillisecond > 0 && b.ConnectUsedMillisecond <= min {
|
||||
min = b.ConnectUsedMillisecond
|
||||
index = i
|
||||
}
|
||||
}
|
||||
return lt.backends[index]
|
||||
}
|
||||
func (lt *LeastTime) IncreasConns(addr string) {
|
||||
|
||||
}
|
||||
func (lt *LeastTime) DecreaseConns(addr string) {
|
||||
|
||||
}
|
||||
func (lt *LeastTime) Stop() {
|
||||
for _, b := range lt.backends {
|
||||
b.StopHeartCheck()
|
||||
}
|
||||
}
|
||||
func (lt *LeastTime) IsActive() bool {
|
||||
for _, b := range lt.backends {
|
||||
if b.Active {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (lt *LeastTime) ActiveCount() (count int) {
|
||||
for _, b := range lt.backends {
|
||||
if b.Active {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (lt *LeastTime) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
|
||||
lt.Lock()
|
||||
defer lt.Unlock()
|
||||
lt.Stop()
|
||||
bks := []*Backend{}
|
||||
for _, c := range configs {
|
||||
b, _ := NewBackend(*c, dr, log)
|
||||
bks = append(bks, b)
|
||||
}
|
||||
if len(bks) > 1 {
|
||||
for _, b := range bks {
|
||||
b.StartHeartCheck()
|
||||
}
|
||||
}
|
||||
lt.backends = bks
|
||||
}
|
||||
func (lt *LeastTime) Backends() []*Backend {
|
||||
return lt.backends
|
||||
}
|
||||
func printDebug(isDebug bool, log *log.Logger, selected *Backend, srcAddr string, backends []*Backend) {
|
||||
if isDebug {
|
||||
log.Printf("############ LB start ############\n")
|
||||
if selected != nil {
|
||||
log.Printf("choosed %s for %s\n", selected.Address, srcAddr)
|
||||
}
|
||||
for _, v := range backends {
|
||||
log.Printf("addr:%s,conns:%d,time:%d,weight:%d,active:%v\n", v.Address, v.Connections, v.ConnectUsedMillisecond, v.Weight, v.Active)
|
||||
}
|
||||
log.Printf("############ LB end ############\n")
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
package utils
|
||||
package mapx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
193
utils/ss/conn.go
Normal file
193
utils/ss/conn.go
Normal file
@ -0,0 +1,193 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
OneTimeAuthMask byte = 0x10
|
||||
AddrMask byte = 0xf
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
*Cipher
|
||||
readBuf []byte
|
||||
writeBuf []byte
|
||||
chunkId uint32
|
||||
}
|
||||
|
||||
func NewConn(c net.Conn, cipher *Cipher) *Conn {
|
||||
return &Conn{
|
||||
Conn: c,
|
||||
Cipher: cipher,
|
||||
readBuf: leakyBuf.Get(),
|
||||
writeBuf: leakyBuf.Get()}
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
leakyBuf.Put(c.readBuf)
|
||||
leakyBuf.Put(c.writeBuf)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func RawAddr(addr string) (buf []byte, err error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ss: address error %s %v", addr, err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ss: invalid port %s", addr)
|
||||
}
|
||||
|
||||
hostLen := len(host)
|
||||
l := 1 + 1 + hostLen + 2 // addrType + lenByte + address + port
|
||||
buf = make([]byte, l)
|
||||
buf[0] = 3 // 3 means the address is domain name
|
||||
buf[1] = byte(hostLen) // host address length followed by host address
|
||||
copy(buf[2:], host)
|
||||
binary.BigEndian.PutUint16(buf[2+hostLen:2+hostLen+2], uint16(port))
|
||||
return
|
||||
}
|
||||
|
||||
// This is intended for use by users implementing a local socks proxy.
|
||||
// rawaddr shoud contain part of the data in socks request, starting from the
|
||||
// ATYP field. (Refer to rfc1928 for more information.)
|
||||
func DialWithRawAddr(rawConn *net.Conn, rawaddr []byte, server string, cipher *Cipher) (c *Conn, err error) {
|
||||
var conn net.Conn
|
||||
if rawConn == nil {
|
||||
conn, err = net.Dial("tcp", server)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if rawConn != nil {
|
||||
c = NewConn(*rawConn, cipher)
|
||||
} else {
|
||||
c = NewConn(conn, cipher)
|
||||
}
|
||||
if cipher.ota {
|
||||
if c.enc == nil {
|
||||
if _, err = c.initEncrypt(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// since we have initEncrypt, we must send iv manually
|
||||
conn.Write(cipher.iv)
|
||||
rawaddr[0] |= OneTimeAuthMask
|
||||
rawaddr = otaConnectAuth(cipher.iv, cipher.key, rawaddr)
|
||||
}
|
||||
if _, err = c.write(rawaddr); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// addr should be in the form of host:port
|
||||
func Dial(addr, server string, cipher *Cipher) (c *Conn, err error) {
|
||||
ra, err := RawAddr(addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return DialWithRawAddr(nil, ra, server, cipher)
|
||||
}
|
||||
|
||||
func (c *Conn) GetIv() (iv []byte) {
|
||||
iv = make([]byte, len(c.iv))
|
||||
copy(iv, c.iv)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) GetKey() (key []byte) {
|
||||
key = make([]byte, len(c.key))
|
||||
copy(key, c.key)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) IsOta() bool {
|
||||
return c.ota
|
||||
}
|
||||
|
||||
func (c *Conn) GetAndIncrChunkId() (chunkId uint32) {
|
||||
chunkId = c.chunkId
|
||||
c.chunkId += 1
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
if c.dec == nil {
|
||||
iv := make([]byte, c.info.ivLen)
|
||||
if _, err = io.ReadFull(c.Conn, iv); err != nil {
|
||||
return
|
||||
}
|
||||
if err = c.initDecrypt(iv); err != nil {
|
||||
return
|
||||
}
|
||||
if len(c.iv) == 0 {
|
||||
c.iv = iv
|
||||
}
|
||||
}
|
||||
|
||||
cipherData := c.readBuf
|
||||
if len(b) > len(cipherData) {
|
||||
cipherData = make([]byte, len(b))
|
||||
} else {
|
||||
cipherData = cipherData[:len(b)]
|
||||
}
|
||||
|
||||
n, err = c.Conn.Read(cipherData)
|
||||
if n > 0 {
|
||||
c.decrypt(b[0:n], cipherData[0:n])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
nn := len(b)
|
||||
if c.ota {
|
||||
chunkId := c.GetAndIncrChunkId()
|
||||
b = otaReqChunkAuth(c.iv, chunkId, b)
|
||||
}
|
||||
headerLen := len(b) - nn
|
||||
|
||||
n, err = c.write(b)
|
||||
// Make sure <= 0 <= len(b), where b is the slice passed in.
|
||||
if n >= headerLen {
|
||||
n -= headerLen
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) write(b []byte) (n int, err error) {
|
||||
var iv []byte
|
||||
if c.enc == nil {
|
||||
iv, err = c.initEncrypt()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cipherData := c.writeBuf
|
||||
dataSize := len(b) + len(iv)
|
||||
if dataSize > len(cipherData) {
|
||||
cipherData = make([]byte, dataSize)
|
||||
} else {
|
||||
cipherData = cipherData[:dataSize]
|
||||
}
|
||||
|
||||
if iv != nil {
|
||||
// Put initialization vector in buffer, do a single write to send both
|
||||
// iv and data.
|
||||
copy(cipherData, iv)
|
||||
}
|
||||
|
||||
c.encrypt(cipherData[len(iv):], b)
|
||||
n, err = c.Conn.Write(cipherData)
|
||||
return
|
||||
}
|
||||
301
utils/ss/encrypt.go
Normal file
301
utils/ss/encrypt.go
Normal file
@ -0,0 +1,301 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/des"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/rc4"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/Yawning/chacha20"
|
||||
"golang.org/x/crypto/blowfish"
|
||||
"golang.org/x/crypto/cast5"
|
||||
"golang.org/x/crypto/salsa20/salsa"
|
||||
)
|
||||
|
||||
var errEmptyPassword = errors.New("empty key")
|
||||
|
||||
func md5sum(d []byte) []byte {
|
||||
h := md5.New()
|
||||
h.Write(d)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func evpBytesToKey(password string, keyLen int) (key []byte) {
|
||||
const md5Len = 16
|
||||
|
||||
cnt := (keyLen-1)/md5Len + 1
|
||||
m := make([]byte, cnt*md5Len)
|
||||
copy(m, md5sum([]byte(password)))
|
||||
|
||||
// Repeatedly call md5 until bytes generated is enough.
|
||||
// Each call to md5 uses data: prev md5 sum + password.
|
||||
d := make([]byte, md5Len+len(password))
|
||||
start := 0
|
||||
for i := 1; i < cnt; i++ {
|
||||
start += md5Len
|
||||
copy(d, m[start-md5Len:start])
|
||||
copy(d[md5Len:], password)
|
||||
copy(m[start:], md5sum(d))
|
||||
}
|
||||
return m[:keyLen]
|
||||
}
|
||||
|
||||
type DecOrEnc int
|
||||
|
||||
const (
|
||||
Decrypt DecOrEnc = iota
|
||||
Encrypt
|
||||
)
|
||||
|
||||
func newStream(block cipher.Block, err error, key, iv []byte,
|
||||
doe DecOrEnc) (cipher.Stream, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if doe == Encrypt {
|
||||
return cipher.NewCFBEncrypter(block, iv), nil
|
||||
} else {
|
||||
return cipher.NewCFBDecrypter(block, iv), nil
|
||||
}
|
||||
}
|
||||
|
||||
func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cipher.NewCTR(block, iv), nil
|
||||
}
|
||||
|
||||
func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := des.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newBlowFishStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := blowfish.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newCast5Stream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
|
||||
block, err := cast5.NewCipher(key)
|
||||
return newStream(block, err, key, iv, doe)
|
||||
}
|
||||
|
||||
func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
|
||||
h := md5.New()
|
||||
h.Write(key)
|
||||
h.Write(iv)
|
||||
rc4key := h.Sum(nil)
|
||||
|
||||
return rc4.NewCipher(rc4key)
|
||||
}
|
||||
|
||||
func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
|
||||
return chacha20.NewCipher(key, iv)
|
||||
}
|
||||
|
||||
func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
|
||||
return chacha20.NewCipher(key, iv)
|
||||
}
|
||||
|
||||
type salsaStreamCipher struct {
|
||||
nonce [8]byte
|
||||
key [32]byte
|
||||
counter int
|
||||
}
|
||||
|
||||
func (c *salsaStreamCipher) XORKeyStream(dst, src []byte) {
|
||||
var buf []byte
|
||||
padLen := c.counter % 64
|
||||
dataSize := len(src) + padLen
|
||||
if cap(dst) >= dataSize {
|
||||
buf = dst[:dataSize]
|
||||
} else if leakyBufSize >= dataSize {
|
||||
buf = leakyBuf.Get()
|
||||
defer leakyBuf.Put(buf)
|
||||
buf = buf[:dataSize]
|
||||
} else {
|
||||
buf = make([]byte, dataSize)
|
||||
}
|
||||
|
||||
var subNonce [16]byte
|
||||
copy(subNonce[:], c.nonce[:])
|
||||
binary.LittleEndian.PutUint64(subNonce[len(c.nonce):], uint64(c.counter/64))
|
||||
|
||||
// It's difficult to avoid data copy here. src or dst maybe slice from
|
||||
// Conn.Read/Write, which can't have padding.
|
||||
copy(buf[padLen:], src[:])
|
||||
salsa.XORKeyStream(buf, buf, &subNonce, &c.key)
|
||||
copy(dst, buf[padLen:])
|
||||
|
||||
c.counter += len(src)
|
||||
}
|
||||
|
||||
func newSalsa20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
|
||||
var c salsaStreamCipher
|
||||
copy(c.nonce[:], iv[:8])
|
||||
copy(c.key[:], key[:32])
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type cipherInfo struct {
|
||||
keyLen int
|
||||
ivLen int
|
||||
newStream func(key, iv []byte, doe DecOrEnc) (cipher.Stream, error)
|
||||
}
|
||||
|
||||
var cipherMethod = map[string]*cipherInfo{
|
||||
"aes-128-cfb": {16, 16, newAESCFBStream},
|
||||
"aes-192-cfb": {24, 16, newAESCFBStream},
|
||||
"aes-256-cfb": {32, 16, newAESCFBStream},
|
||||
"aes-128-ctr": {16, 16, newAESCTRStream},
|
||||
"aes-192-ctr": {24, 16, newAESCTRStream},
|
||||
"aes-256-ctr": {32, 16, newAESCTRStream},
|
||||
"des-cfb": {8, 8, newDESStream},
|
||||
"bf-cfb": {16, 8, newBlowFishStream},
|
||||
"cast5-cfb": {16, 8, newCast5Stream},
|
||||
"rc4-md5": {16, 16, newRC4MD5Stream},
|
||||
"rc4-md5-6": {16, 6, newRC4MD5Stream},
|
||||
"chacha20": {32, 8, newChaCha20Stream},
|
||||
"chacha20-ietf": {32, 12, newChaCha20IETFStream},
|
||||
"salsa20": {32, 8, newSalsa20Stream},
|
||||
}
|
||||
|
||||
func CheckCipherMethod(method string) error {
|
||||
if method == "" {
|
||||
method = "aes-256-cfb"
|
||||
}
|
||||
_, ok := cipherMethod[method]
|
||||
if !ok {
|
||||
return errors.New("Unsupported encryption method: " + method)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Cipher struct {
|
||||
enc cipher.Stream
|
||||
dec cipher.Stream
|
||||
key []byte
|
||||
info *cipherInfo
|
||||
ota bool // one-time auth
|
||||
iv []byte
|
||||
}
|
||||
|
||||
// NewCipher creates a cipher that can be used in Dial() etc.
|
||||
// Use cipher.Copy() to create a new cipher with the same method and password
|
||||
// to avoid the cost of repeated cipher initialization.
|
||||
func NewCipher(method, password string) (c *Cipher, err error) {
|
||||
if password == "" {
|
||||
return nil, errEmptyPassword
|
||||
}
|
||||
var ota bool
|
||||
if strings.HasSuffix(strings.ToLower(method), "-auth") {
|
||||
method = method[:len(method)-5] // len("-auth") = 5
|
||||
ota = true
|
||||
} else {
|
||||
ota = false
|
||||
}
|
||||
mi, ok := cipherMethod[method]
|
||||
if !ok {
|
||||
return nil, errors.New("Unsupported encryption method: " + method)
|
||||
}
|
||||
|
||||
key := evpBytesToKey(password, mi.keyLen)
|
||||
|
||||
c = &Cipher{key: key, info: mi}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.ota = ota
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Initializes the block cipher with CFB mode, returns IV.
|
||||
func (c *Cipher) initEncrypt() (iv []byte, err error) {
|
||||
if c.iv == nil {
|
||||
iv = make([]byte, c.info.ivLen)
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.iv = iv
|
||||
} else {
|
||||
iv = c.iv
|
||||
}
|
||||
c.enc, err = c.info.newStream(c.key, iv, Encrypt)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Cipher) initDecrypt(iv []byte) (err error) {
|
||||
c.dec, err = c.info.newStream(c.key, iv, Decrypt)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Cipher) encrypt(dst, src []byte) {
|
||||
c.enc.XORKeyStream(dst, src)
|
||||
}
|
||||
|
||||
func (c *Cipher) decrypt(dst, src []byte) {
|
||||
c.dec.XORKeyStream(dst, src)
|
||||
}
|
||||
func (c *Cipher) Encrypt(src []byte) (cipherData []byte) {
|
||||
cipher := c.Copy()
|
||||
iv, err := cipher.initEncrypt()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
packetLen := len(src) + len(iv)
|
||||
cipherData = make([]byte, packetLen)
|
||||
copy(cipherData, iv)
|
||||
cipher.encrypt(cipherData[len(iv):], src)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Cipher) Decrypt(src []byte) (data []byte) {
|
||||
cipher := c.Copy()
|
||||
if len(src) < c.info.ivLen {
|
||||
return
|
||||
}
|
||||
iv := make([]byte, c.info.ivLen)
|
||||
copy(iv, src[:c.info.ivLen])
|
||||
if err := cipher.initDecrypt(iv); err != nil {
|
||||
return
|
||||
}
|
||||
data = make([]byte, len(src)-len(iv))
|
||||
cipher.decrypt(data[0:], src[c.info.ivLen:])
|
||||
return
|
||||
}
|
||||
|
||||
// Copy creates a new cipher at it's initial state.
|
||||
func (c *Cipher) Copy() *Cipher {
|
||||
// This optimization maybe not necessary. But without this function, we
|
||||
// need to maintain a table cache for newTableCipher and use lock to
|
||||
// protect concurrent access to that cache.
|
||||
|
||||
// AES and DES ciphers does not return specific types, so it's difficult
|
||||
// to create copy. But their initizliation time is less than 4000ns on my
|
||||
// 2.26 GHz Intel Core 2 Duo processor. So no need to worry.
|
||||
|
||||
// Currently, blow-fish and cast5 initialization cost is an order of
|
||||
// maganitude slower than other ciphers. (I'm not sure whether this is
|
||||
// because the current implementation is not highly optimized, or this is
|
||||
// the nature of the algorithm.)
|
||||
|
||||
nc := *c
|
||||
nc.enc = nil
|
||||
nc.dec = nil
|
||||
nc.ota = c.ota
|
||||
return &nc
|
||||
}
|
||||
131
utils/ss/util.go
Normal file
131
utils/ss/util.go
Normal file
@ -0,0 +1,131 @@
|
||||
package ss
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"bitbucket.org/snail/proxy/utils"
|
||||
)
|
||||
|
||||
const leakyBufSize = 4108 // data.len(2) + hmacsha1(10) + data(4096)
|
||||
const maxNBuf = 2048
|
||||
|
||||
var leakyBuf = utils.NewLeakyBuf(maxNBuf, leakyBufSize)
|
||||
|
||||
func IsFileExists(path string) (bool, error) {
|
||||
stat, err := os.Stat(path)
|
||||
if err == nil {
|
||||
if stat.Mode()&os.ModeType == 0 {
|
||||
return true, nil
|
||||
}
|
||||
return false, errors.New(path + " exists but is not regular file")
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func HmacSha1(key []byte, data []byte) []byte {
|
||||
hmacSha1 := hmac.New(sha1.New, key)
|
||||
hmacSha1.Write(data)
|
||||
return hmacSha1.Sum(nil)[:10]
|
||||
}
|
||||
|
||||
func otaConnectAuth(iv, key, data []byte) []byte {
|
||||
return append(data, HmacSha1(append(iv, key...), data)...)
|
||||
}
|
||||
|
||||
func otaReqChunkAuth(iv []byte, chunkId uint32, data []byte) []byte {
|
||||
nb := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(nb, uint16(len(data)))
|
||||
chunkIdBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(chunkIdBytes, chunkId)
|
||||
header := append(nb, HmacSha1(append(iv, chunkIdBytes...), data)...)
|
||||
return append(header, data...)
|
||||
}
|
||||
|
||||
const (
|
||||
idType = 0 // address type index
|
||||
idIP0 = 1 // ip addres start index
|
||||
idDmLen = 1 // domain address length index
|
||||
idDm0 = 2 // domain address start index
|
||||
|
||||
typeIPv4 = 1 // type is ipv4 address
|
||||
typeDm = 3 // type is domain address
|
||||
typeIPv6 = 4 // type is ipv6 address
|
||||
|
||||
lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
|
||||
lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
|
||||
lenDmBase = 2 // 1addrLen + 2port, plus addrLen
|
||||
lenHmacSha1 = 10
|
||||
)
|
||||
|
||||
func GetRequest(conn *Conn) (host string, err error) {
|
||||
|
||||
// buf size should at least have the same size with the largest possible
|
||||
// request size (when addrType is 3, domain name has at most 256 bytes)
|
||||
// 1(addrType) + 1(lenByte) + 255(max length address) + 2(port) + 10(hmac-sha1)
|
||||
buf := make([]byte, 269)
|
||||
// read till we get possible domain length field
|
||||
if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var reqStart, reqEnd int
|
||||
addrType := buf[idType]
|
||||
switch addrType & AddrMask {
|
||||
case typeIPv4:
|
||||
reqStart, reqEnd = idIP0, idIP0+lenIPv4
|
||||
case typeIPv6:
|
||||
reqStart, reqEnd = idIP0, idIP0+lenIPv6
|
||||
case typeDm:
|
||||
if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
|
||||
return
|
||||
}
|
||||
reqStart, reqEnd = idDm0, idDm0+int(buf[idDmLen])+lenDmBase
|
||||
default:
|
||||
err = fmt.Errorf("addr type %d not supported", addrType&AddrMask)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Return string for typeIP is not most efficient, but browsers (Chrome,
|
||||
// Safari, Firefox) all seems using typeDm exclusively. So this is not a
|
||||
// big problem.
|
||||
switch addrType & AddrMask {
|
||||
case typeIPv4:
|
||||
host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
|
||||
case typeIPv6:
|
||||
host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String()
|
||||
case typeDm:
|
||||
host = string(buf[idDm0 : idDm0+int(buf[idDmLen])])
|
||||
}
|
||||
// parse port
|
||||
port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
|
||||
host = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type ClosedFlag struct {
|
||||
flag bool
|
||||
}
|
||||
|
||||
func (flag *ClosedFlag) SetClosed() {
|
||||
flag.flag = true
|
||||
}
|
||||
|
||||
func (flag *ClosedFlag) IsClosed() bool {
|
||||
return flag.flag
|
||||
}
|
||||
Reference in New Issue
Block a user