179 lines
4.0 KiB
Go
179 lines
4.0 KiB
Go
package iolimiter
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// burstLimit is the burst size given to every rate.Limiter this package
// creates. Each limiter has this burst spent immediately after construction,
// so it does not grant free initial throughput; per the rate package docs it
// also caps how many bytes a single WaitN call may request.
const burstLimit = 1000000000
|
|
|
|
type Reader struct {
|
|
r io.Reader
|
|
limiter *rate.Limiter
|
|
ctx context.Context
|
|
}
|
|
|
|
type Writer struct {
|
|
w io.Writer
|
|
limiter *rate.Limiter
|
|
ctx context.Context
|
|
}
|
|
|
|
type conn struct {
|
|
net.Conn
|
|
r io.Reader
|
|
w io.Writer
|
|
readLimiter *rate.Limiter
|
|
writeLimiter *rate.Limiter
|
|
ctx context.Context
|
|
}
|
|
|
|
//NewtRateLimitConn sets rate limit (bytes/sec) to the Conn read and write.
|
|
func NewtConn(c net.Conn, bytesPerSec float64) net.Conn {
|
|
s := &conn{
|
|
Conn: c,
|
|
r: c,
|
|
w: c,
|
|
ctx: context.Background(),
|
|
}
|
|
s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
|
s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
|
s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
|
s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
|
return s
|
|
}
|
|
|
|
//NewtRateLimitReaderConn sets rate limit (bytes/sec) to the Conn read.
|
|
func NewReaderConn(c net.Conn, bytesPerSec float64) net.Conn {
|
|
s := &conn{
|
|
Conn: c,
|
|
r: c,
|
|
w: c,
|
|
ctx: context.Background(),
|
|
}
|
|
s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
|
s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
|
return s
|
|
}
|
|
|
|
//NewtRateLimitWriterConn sets rate limit (bytes/sec) to the Conn write.
|
|
func NewWriterConn(c net.Conn, bytesPerSec float64) net.Conn {
|
|
s := &conn{
|
|
Conn: c,
|
|
r: c,
|
|
w: c,
|
|
ctx: context.Background(),
|
|
}
|
|
s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
|
s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
|
return s
|
|
}
|
|
|
|
// Read reads bytes into p.
|
|
func (s *conn) Read(p []byte) (int, error) {
|
|
if s.readLimiter == nil {
|
|
return s.r.Read(p)
|
|
}
|
|
n, err := s.r.Read(p)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
if err := s.readLimiter.WaitN(s.ctx, n); err != nil {
|
|
return n, err
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// Write writes bytes from p.
|
|
func (s *conn) Write(p []byte) (int, error) {
|
|
if s.writeLimiter == nil {
|
|
return s.w.Write(p)
|
|
}
|
|
n, err := s.w.Write(p)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
if err := s.writeLimiter.WaitN(s.ctx, n); err != nil {
|
|
return n, err
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// NewReader returns a reader that implements io.Reader with rate limiting.
|
|
func NewReader(r io.Reader) *Reader {
|
|
return &Reader{
|
|
r: r,
|
|
ctx: context.Background(),
|
|
}
|
|
}
|
|
|
|
// NewReaderWithContext returns a reader that implements io.Reader with rate limiting.
|
|
func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader {
|
|
return &Reader{
|
|
r: r,
|
|
ctx: ctx,
|
|
}
|
|
}
|
|
|
|
// NewWriter returns a writer that implements io.Writer with rate limiting.
|
|
func NewWriter(w io.Writer) *Writer {
|
|
return &Writer{
|
|
w: w,
|
|
ctx: context.Background(),
|
|
}
|
|
}
|
|
|
|
// NewWriterWithContext returns a writer that implements io.Writer with rate limiting.
|
|
func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer {
|
|
return &Writer{
|
|
w: w,
|
|
ctx: ctx,
|
|
}
|
|
}
|
|
|
|
// SetRateLimit sets rate limit (bytes/sec) to the reader.
|
|
func (s *Reader) SetRateLimit(bytesPerSec float64) {
|
|
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
|
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
|
}
|
|
|
|
// Read reads bytes into p.
|
|
func (s *Reader) Read(p []byte) (int, error) {
|
|
if s.limiter == nil {
|
|
return s.r.Read(p)
|
|
}
|
|
n, err := s.r.Read(p)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
|
return n, err
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// SetRateLimit sets rate limit (bytes/sec) to the writer.
|
|
func (s *Writer) SetRateLimit(bytesPerSec float64) {
|
|
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
|
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
|
}
|
|
|
|
// Write writes bytes from p.
|
|
func (s *Writer) Write(p []byte) (int, error) {
|
|
if s.limiter == nil {
|
|
return s.w.Write(p)
|
|
}
|
|
n, err := s.w.Write(p)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
|
return n, err
|
|
}
|
|
return n, err
|
|
}
|