diff --git a/CHANGELOG b/CHANGELOG index 4816624..1251942 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,23 +1,28 @@ proxy更新日志: +v3.0 +1.增加了代理死循环检查,增强了安全性。 +2.重构了全部代码,为下一步的功能拓展做准备。 +3.此次更新不兼容2.x版本。 + v2.2 -1.增加了强制使用上级代理参数always.可以使所有流量都走上级代理. -2.增加了定时检查网络是否正常,可以在本地网络不稳定的时候修复连接池状态,提升代理访问体验. +1.增加了强制使用上级代理参数always.可以使所有流量都走上级代理。 +2.增加了定时检查网络是否正常,可以在本地网络不稳定的时候修复连接池状态,提升代理访问体验。 3.http代理增加了对ipv6地址的支持。 v2.1 -1.增加了http basic验证功能,可以对http代理协议设置basic验证,用户名和密码支持来自文件或者命令行. +1.增加了http basic验证功能,可以对http代理协议设置basic验证,用户名和密码支持来自文件或者命令行。 2.优化了域名检查方法,避免空连接的出现。 -3.修复了连接上级代理超时参数传递错误导致超时过大的问题. -4.增加了连接池状态监测,如果上级代理或者网络出现问题,会及时重新初始化连接池,防止大量无效连接,降低浏览体验. -5.增加了对系统kill信号的捕获,可以在收到系统kill信号之后执行清理释放连接的操作.避免出现大量CLOSE_WAIT. +3.修复了连接上级代理超时参数传递错误导致超时过大的问题。 +4.增加了连接池状态监测,如果上级代理或者网络出现问题,会及时重新初始化连接池,防止大量无效连接,降低浏览体验。 +5.增加了对系统kill信号的捕获,可以在收到系统kill信号之后执行清理释放连接的操作.避免出现大量CLOSE_WAIT。 v2.0 -1.增加了连接池功能,大幅提高了通过上级代理访问的速度. +1.增加了连接池功能,大幅提高了通过上级代理访问的速度。 2.HTTP代理模式,优化了请求URL的获取逻辑,可以支持:http,https,websocket -3.增加了TCP代理模式,支持是否加密通讯. -4.优化了链接关闭逻辑,避免出现大量CLOSE_WAIT. -5.增加了黑白名单机制,更自由快速的访问. -6.优化了网站Block机制检测,判断更准确. +3.增加了TCP代理模式,支持是否加密通讯。 +4.优化了链接关闭逻辑,避免出现大量CLOSE_WAIT。 +5.增加了黑白名单机制,更自由快速的访问。 +6.优化了网站Block机制检测,判断更准确。 v1.0 -1.始发版本,可以代理http,https. +1.始发版本,可以代理http,https。 diff --git a/config.go b/config.go index def80ae..1f18c94 100755 --- a/config.go +++ b/config.go @@ -1,111 +1,89 @@ package main import ( - "flag" "fmt" + "io/ioutil" "log" "os" - "strings" + "proxy/services" + "proxy/utils" - "github.com/spf13/pflag" - "github.com/spf13/viper" + kingpin "gopkg.in/alecthomas/kingpin.v2" ) var ( - cfg = viper.New() + app *kingpin.Application + service services.ServiceItem ) func initConfig() (err error) { - //define command line args + args := services.Args{} + //define args + tcpArgs := services.TCPArgs{} + httpArgs := services.HTTPArgs{} + tlsArgs := services.TLSArgs{} + udpArgs := services.UDPArgs{} - pflag.CommandLine.AddGoFlagSet(flag.CommandLine) - configFile := pflag.StringP("config", "c", "", "config file path") + //build srvice args + app = kingpin.New("proxy", "happy with proxy") + app.Author("snail").Version(APP_VERSION) + args.Parent = app.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String() + args.Local = app.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String() + certTLS := app.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String() + keyTLS := app.Flag("key", "key file for tls").Short('K').Default("proxy.key").String() + args.PoolSize = app.Flag("pool-size", "conn pool size , which connect to parent proxy, zero: means turn off pool").Default("50").Int() + args.CheckParentInterval = app.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Default("3").Int() - pflag.BoolP("parent-tls", "X", false, "parent proxy is tls") - pflag.BoolP("local-tls", "x", false, "local proxy is tls") - pflag.BoolP("parent-tcp", "W", false, "parent proxy is tcp") - pflag.BoolP("local-tcp", "w", true, "local proxy is tcp") - pflag.BoolP("parent-udp", "U", false, "parent is udp") - pflag.BoolP("local-udp", "u", false, "local proxy is udp") - version := pflag.BoolP("version", "v", false, "show version") - pflag.BoolP("local-http", "z", false, "proxy on http") - pflag.Bool("always", false, "always use parent proxy") + //########http######### + http := app.Command("http", "proxy on http mode") + httpArgs.LocalType = http.Flag("local-type", "parent protocol type ").Default("tcp").Short('t').Enum("tls", "tcp") + httpArgs.ParentType = http.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp") + httpArgs.Always = http.Flag("always", "always use parent proxy").Default("false").Bool() + httpArgs.Timeout = http.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int() + httpArgs.HTTPTimeout = http.Flag("http-timeout", "check domain if blocked , http request timeout milliseconds when connect to host").Default("3000").Int() + httpArgs.Interval = http.Flag("interval", "check domain if blocked every interval seconds").Default("10").Int() + httpArgs.Blocked = http.Flag("blocked", "blocked domain file , one domain each line").Default("blocked").Short('b').String() + httpArgs.Direct = http.Flag("direct", "direct domain file , one domain each line").Default("direct").Short('d').String() + httpArgs.AuthFile = http.Flag("auth-file", "http basic auth file,\"username:password\" each line in file").Short('F').String() + httpArgs.Auth = http.Flag("auth", "http basic auth username and password, mutiple user repeat -a ,such as: -a user1:pass1 -a user2:pass2").Short('a').Strings() - pflag.Int("check-proxy-interval", 3, "check if proxy is okay every interval seconds") - pflag.IntP("port", "p", 33080, "local port to listen") - pflag.IntP("check-timeout", "t", 3000, "chekc domain blocked , http request timeout milliseconds when connect to host") - pflag.IntP("tcp-timeout", "T", 2000, "tcp timeout milliseconds when connect to real server or parent proxy") - pflag.IntP("check-interval", "I", 10, "check domain if blocked every interval seconds") - pflag.IntP("pool-size", "s", 50, "conn pool size , which connect to parent proxy, zero: means turn off pool") + //########tcp######### + tcp := app.Command("tcp", "proxy on tcp mode") + tcpArgs.Timeout = tcp.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int() + tcpArgs.ParentType = tcp.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "udp") - pflag.StringP("parent", "P", "", "parent proxy address") - pflag.StringP("ip", "i", "0.0.0.0", "local ip to bind") - pflag.StringP("cert", "f", "proxy.crt", "cert file for tls") - pflag.StringP("key", "k", "proxy.key", "key file for tls") - pflag.StringP("blocked", "b", "blocked", "blocked domain file , one domain each line") - pflag.StringP("direct", "d", "direct", "direct domain file , one domain each line") - pflag.StringP("auth-file", "F", "", "http basic auth file,\"username:password\" each line in file") - pflag.StringSliceP("auth", "a", []string{}, "http basic auth username and password,such as: \"user1:pass1,user2:pass2\"") + //########tls######### + tls := app.Command("tls", "proxy on tls mode") + tlsArgs.Timeout = tls.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int() + tlsArgs.ParentType = tls.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "udp") - pflag.Parse() + kingpin.MustParse(app.Parse(os.Args[1:])) - cfg.BindPFlag("parent-tls", pflag.Lookup("parent-tls")) - cfg.BindPFlag("local-tls", pflag.Lookup("local-tls")) - cfg.BindPFlag("parent-udp", pflag.Lookup("parent-udp")) - cfg.BindPFlag("local-udp", pflag.Lookup("local-udp")) - cfg.BindPFlag("parent-tcp", pflag.Lookup("parent-tcp")) - cfg.BindPFlag("local-tcp", pflag.Lookup("local-tcp")) - cfg.BindPFlag("local-http", pflag.Lookup("local-http")) - cfg.BindPFlag("always", pflag.Lookup("always")) - cfg.BindPFlag("check-proxy-interval", pflag.Lookup("check-proxy-interval")) - cfg.BindPFlag("port", pflag.Lookup("port")) - cfg.BindPFlag("check-timeout", pflag.Lookup("check-timeout")) - cfg.BindPFlag("tcp-timeout", pflag.Lookup("tcp-timeout")) - cfg.BindPFlag("check-interval", pflag.Lookup("check-interval")) - cfg.BindPFlag("pool-size", pflag.Lookup("pool-size")) - cfg.BindPFlag("parent", pflag.Lookup("parent")) - cfg.BindPFlag("ip", pflag.Lookup("ip")) - cfg.BindPFlag("cert", pflag.Lookup("cert")) - cfg.BindPFlag("key", pflag.Lookup("key")) - cfg.BindPFlag("blocked", pflag.Lookup("blocked")) - cfg.BindPFlag("direct", pflag.Lookup("direct")) - cfg.BindPFlag("auth", pflag.Lookup("auth")) - cfg.BindPFlag("auth-file", pflag.Lookup("auth-file")) - - //version - if *version { - fmt.Printf("proxy v%s\n", APP_VERSION) - os.Exit(0) + if *certTLS != "" && *keyTLS != "" { + args.CertBytes, args.KeyBytes = tlsBytes(*certTLS, *keyTLS) } + httpArgs.Args = args + tcpArgs.Args = args + tlsArgs.Args = args + udpArgs.Args = args //keygen - if len(pflag.Args()) > 0 { - if pflag.Arg(0) == "keygen" { - keygen() + if len(os.Args) > 1 { + if os.Args[1] == "keygen" { + utils.Keygen() os.Exit(0) } } - - poster() - - if *configFile != "" { - cfg.SetConfigFile(*configFile) - } else { - cfg.SetConfigName("proxy") - cfg.AddConfigPath("/etc/proxy/") - cfg.AddConfigPath("$HOME/.proxy") - cfg.AddConfigPath(".proxy") - cfg.AddConfigPath(".") + //regist services and run service + serviceName := kingpin.MustParse(app.Parse(os.Args[1:])) + services.Regist("http", services.NewHTTP(), httpArgs) + services.Regist("tcp", services.NewTCP(), tcpArgs) + services.Regist("tls", services.NewTLS(), tlsArgs) + services.Regist("udp", services.NewUDP(), udpArgs) + service, err = services.Run(serviceName) + if err != nil { + log.Fatalf("run service [%s] fail, ERR:%s", service, err) } - - err = cfg.ReadInConfig() - file := cfg.ConfigFileUsed() - if err != nil && !strings.Contains(err.Error(), "Not") { - log.Fatalf("parse config fail, ERR:%s", err) - } else if file != "" { - log.Printf("use config file : %s", file) - } - err = nil return } @@ -121,3 +99,16 @@ func poster() { v%s`+" by snail , blog : http://www.host900.com/\n\n", APP_VERSION) } +func tlsBytes(cert, key string) (certBytes, keyBytes []byte) { + certBytes, err := ioutil.ReadFile(cert) + if err != nil { + log.Fatalf("err : %s", err) + return + } + keyBytes, err = ioutil.ReadFile(key) + if err != nil { + log.Fatalf("err : %s", err) + return + } + return +} diff --git a/main.go b/main.go index 6435e08..53a16be 100644 --- a/main.go +++ b/main.go @@ -2,290 +2,313 @@ package main import ( "fmt" - "io" - "io/ioutil" "log" - "net" - "os/exec" - "runtime/debug" - "time" + "os" + "os/signal" + "proxy/services" + "syscall" ) const APP_VERSION = "2.2" -var ( - checker Checker - certBytes []byte - keyBytes []byte - outPool ConnPool - basicAuth BasicAuth -) +// var ( +// checker Checker +// certBytes []byte +// keyBytes []byte +// outPool ConnPool +// basicAuth BasicAuth +// ) func init() { + + // return + // //Init + // err = Init() + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // isLocalHTTP := cfg.GetBool("local-http") + // isTLS := cfg.GetBool("local-tls") || cfg.GetBool("parent-tls") + // isTCP := isLocalHTTP || isTLS || cfg.GetBool("local-tcp") || cfg.GetBool("parent-tcp") + + // //InitTCP + // if isTCP { + // err = InitTCP() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitTLS + // if isTLS { + // err = InitTLS() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitUDP + // if cfg.GetBool("local-udp") || cfg.GetBool("parent-udp") { + // err = InitUDP() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitLocal + // err = InitLocal() + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitLocalTCP + // if cfg.GetBool("local-tcp") || cfg.GetBool("local-tls") || isLocalHTTP { + // err = InitLocalTCP() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitLocalTLS + // if cfg.GetBool("local-tls") { + // err = InitLocalTLS() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitLocalHTTP + // if isLocalHTTP { + // err = InitLocalHTTP() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitLocalUDP + // if cfg.GetBool("local-udp") { + // err = InitLocalUDP() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitParent + // if cfg.GetString("parent") != "" { + // err = InitParent() + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitParentTCP + // if cfg.GetBool("parent-tcp") || cfg.GetBool("parent-tls") { + // err = InitParentTCP() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitParentTLS + // if cfg.GetBool("parent-tls") { + // err = InitParentTLS() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // //InitParentUDP + // if cfg.GetBool("parent-udp") { + // err = InitParentUDP() + // } + // if err != nil { + // log.Fatalf("err : %s", err) + // } + // } +} +func main() { err := initConfig() if err != nil { log.Fatalf("err : %s", err) } - //Init - err = Init() - if err != nil { - log.Fatalf("err : %s", err) - } - isLocalHTTP := cfg.GetBool("local-http") - isTLS := cfg.GetBool("local-tls") || cfg.GetBool("parent-tls") - isTCP := isLocalHTTP || isTLS || cfg.GetBool("local-tcp") || cfg.GetBool("parent-tcp") - - //InitTCP - if isTCP { - err = InitTCP() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitTLS - if isTLS { - err = InitTLS() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitUDP - if cfg.GetBool("local-udp") || cfg.GetBool("parent-udp") { - err = InitUDP() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitLocal - err = InitLocal() - if err != nil { - log.Fatalf("err : %s", err) - } - //InitLocalTCP - if cfg.GetBool("local-tcp") || cfg.GetBool("local-tls") || isLocalHTTP { - err = InitLocalTCP() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitLocalTLS - if cfg.GetBool("local-tls") { - err = InitLocalTLS() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitLocalHTTP - if isLocalHTTP { - err = InitLocalHTTP() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitLocalUDP - if cfg.GetBool("local-udp") { - err = InitLocalUDP() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitParent - if cfg.GetString("parent") != "" { - err = InitParent() - if err != nil { - log.Fatalf("err : %s", err) - } - //InitParentTCP - if cfg.GetBool("parent-tcp") || cfg.GetBool("parent-tls") { - err = InitParentTCP() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitParentTLS - if cfg.GetBool("parent-tls") { - err = InitParentTLS() - } - if err != nil { - log.Fatalf("err : %s", err) - } - //InitParentUDP - if cfg.GetBool("parent-udp") { - err = InitParentUDP() - } - if err != nil { - log.Fatalf("err : %s", err) - } - } + Clean(&service.S) } -func Init() (err error) { - return -} -func InitTCP() (err error) { - return -} -func InitTLS() (err error) { - certBytes, err = ioutil.ReadFile(cfg.GetString("cert")) - if err != nil { - log.Printf("err : %s", err) - return - } - keyBytes, err = ioutil.ReadFile(cfg.GetString("key")) - if err != nil { - log.Printf("err : %s", err) - return - } - return -} -func InitUDP() (err error) { - return -} -func InitLocal() (err error) { - return -} -func InitLocalTCP() (err error) { - return -} -func InitLocalTLS() (err error) { - - return -} -func InitLocalHTTP() (err error) { - err = InitBasicAuth() - if err != nil { - return - } - return -} - -func InitLocalUDP() (err error) { - return -} -func InitParent() (err error) { - initOutPool(cfg.GetBool("parent-tls"), certBytes, keyBytes, cfg.GetString("parent"), cfg.GetInt("tcp-timeout"), cfg.GetInt("pool-size"), cfg.GetInt("pool-size")*2) - checker = NewChecker(cfg.GetInt("check-timeout"), int64(cfg.GetInt("check-interval")), cfg.GetString("blocked"), cfg.GetString("direct")) - log.Printf("use parent proxy : %s, udp : %v, tcp : %v, tls: %v", cfg.GetString("parent"), cfg.GetBool("parent-udp"), cfg.GetBool("parent-tcp"), cfg.GetBool("parent-tls")) - return -} -func InitParentTCP() (err error) { - return -} -func InitParentTLS() (err error) { - return -} -func InitParentUDP() (err error) { - return -} -func main() { - //catch panic error - defer func() { - e := recover() - if e != nil { - log.Printf("err : %s,\ntrace:%s", e, string(debug.Stack())) +func Clean(s *services.Service) { + //block main() + signalChan := make(chan os.Signal, 1) + cleanupDone := make(chan bool) + signal.Notify(signalChan, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT) + go func() { + for _ = range signalChan { + fmt.Println("\nReceived an interrupt, stopping services...") + (*s).Clean() + cleanupDone <- true } }() - - sc := NewServerChannel(cfg.GetString("ip"), cfg.GetInt("port")) - if cfg.GetBool("local-tls") { - LocalTLSServer(&sc) - } else if cfg.GetBool("local-tcp") { - LocalTCPServer(&sc) - } else if cfg.GetBool("local-udp") { - LocalUDPServer(&sc) - } - log.Printf("proxy on %s , udp: %v, tcp: %v, tls: %v ,http: %v", (*sc.Listener).Addr(), cfg.GetBool("local-udp"), cfg.GetBool("local-tcp"), cfg.GetBool("local-tls"), cfg.GetBool("local-http")) - - clean() + <-cleanupDone } -func CheckTCPDeocder(inConn *net.Conn) (useProxy bool, address string, req *HTTPRequest, err error) { - if cfg.GetBool("local-http") { - useProxy, req, err = HTTPProxyDecoder(inConn) - if err != nil { - return - } - address = req.Host - } else { - address = cfg.GetString("parent") - } - if address == "" { - useProxy = false - } else if cfg.GetBool("always") { - useProxy = true - } - return -} -func LocalTCPServer(sc *ServerChannel) { - (*sc).ListenTCP(func(inConn net.Conn) { - userProxy, address, req, err := CheckTCPDeocder(&inConn) - if err != nil { - if err != io.EOF { - log.Printf("tcp decode error , ERR:%s", err) - } - return - } - TCPOutBridge(&inConn, userProxy, address, req) - }) -} +// func Init() (err error) { +// return +// } +// func InitTCP() (err error) { +// return +// } +// func InitTLS() (err error) { +// certBytes, err = ioutil.ReadFile(cfg.GetString("cert")) +// if err != nil { +// log.Printf("err : %s", err) +// return +// } +// keyBytes, err = ioutil.ReadFile(cfg.GetString("key")) +// if err != nil { +// log.Printf("err : %s", err) +// return +// } +// return +// } +// func InitUDP() (err error) { +// return +// } +// func InitLocal() (err error) { +// return +// } +// func InitLocalTCP() (err error) { +// return +// } +// func InitLocalTLS() (err error) { -func LocalTLSServer(sc *ServerChannel) { - (*sc).ListenTls(certBytes, keyBytes, func(inConn net.Conn) { - userProxy, address, req, err := CheckTCPDeocder(&inConn) - if err != nil { - if err != io.EOF { - log.Printf("tls decode error , ERR:%s", err) - } - return - } - TCPOutBridge(&inConn, userProxy, address, req) - }) -} +// return +// } +// func InitLocalHTTP() (err error) { +// err = InitBasicAuth() +// if err != nil { +// return +// } +// return +// } -func LocalUDPServer(sc *ServerChannel) { - (*sc).ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) { +// func InitLocalUDP() (err error) { +// return +// } +// func InitParent() (err error) { +// initOutPool(cfg.GetBool("parent-tls"), certBytes, keyBytes, cfg.GetString("parent"), cfg.GetInt("tcp-timeout"), cfg.GetInt("pool-size"), cfg.GetInt("pool-size")*2) +// checker = NewChecker(cfg.GetInt("check-timeout"), int64(cfg.GetInt("check-interval")), cfg.GetString("blocked"), cfg.GetString("direct")) +// log.Printf("use parent proxy : %s, udp : %v, tcp : %v, tls: %v", cfg.GetString("parent"), cfg.GetBool("parent-udp"), cfg.GetBool("parent-tcp"), cfg.GetBool("parent-tls")) +// return +// } +// func InitParentTCP() (err error) { +// return +// } +// func InitParentTLS() (err error) { +// return +// } +// func InitParentUDP() (err error) { +// return +// } +// func main() { +// //catch panic error +// defer func() { +// e := recover() +// if e != nil { +// log.Printf("err : %s,\ntrace:%s", e, string(debug.Stack())) +// } +// }() +// return +// sc := NewServerChannel(cfg.GetString("ip"), cfg.GetInt("port")) +// if cfg.GetBool("local-tls") { +// LocalTLSServer(&sc) +// } else if cfg.GetBool("local-tcp") { +// LocalTCPServer(&sc) +// } else if cfg.GetBool("local-udp") { +// LocalUDPServer(&sc) +// } +// log.Printf("proxy on %s , udp: %v, tcp: %v, tls: %v ,http: %v", (*sc.Listener).Addr(), cfg.GetBool("local-udp"), cfg.GetBool("local-tcp"), cfg.GetBool("local-tls"), cfg.GetBool("local-http")) - }) -} +// clean() +// } -func TCPOutBridge(inConn *net.Conn, userProxy bool, address string, req *HTTPRequest) { - var outConn net.Conn - var _outConn interface{} - var err error - if userProxy { - _outConn, err = outPool.Get() - if err == nil { - outConn = _outConn.(net.Conn) - } - } else { - outConn, err = ConnectHost(address, cfg.GetInt("tcp-timeout")) - } - if err != nil { - log.Printf("connect to %s , err:%s", address, err) - closeConn(inConn) - return - } - inAddr := (*inConn).RemoteAddr().String() - outAddr := outConn.RemoteAddr().String() - log.Printf("%s use proxy %v", address, userProxy) +// func CheckTCPDeocder(inConn *net.Conn) (useProxy bool, address string, req *HTTPRequest, err error) { +// if cfg.GetBool("local-http") { +// useProxy, req, err = HTTPProxyDecoder(inConn) +// if err != nil { +// return +// } +// address = req.Host +// } else { +// address = cfg.GetString("parent") +// } +// if address == "" { +// useProxy = false +// } else if cfg.GetBool("always") { +// useProxy = true +// } +// return +// } +// func LocalTCPServer(sc *ServerChannel) { +// (*sc).ListenTCP(func(inConn net.Conn) { +// userProxy, address, req, err := CheckTCPDeocder(&inConn) +// if err != nil { +// if err != io.EOF { +// log.Printf("tcp decode error , ERR:%s", err) +// } +// return +// } +// TCPOutBridge(&inConn, userProxy, address, req) +// }) +// } - if req != nil { - if req.IsHTTPS() && !userProxy { - req.HTTPSReply() - } else { - outConn.Write(req.headBuf) - } - } +// func LocalTLSServer(sc *ServerChannel) { +// (*sc).ListenTls(certBytes, keyBytes, func(inConn net.Conn) { +// userProxy, address, req, err := CheckTCPDeocder(&inConn) +// if err != nil { +// if err != io.EOF { +// log.Printf("tls decode error , ERR:%s", err) +// } +// return +// } +// TCPOutBridge(&inConn, userProxy, address, req) +// }) +// } - IoBind(*inConn, outConn, func(err error) { - log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address) - closeConn(inConn) - closeConn(&outConn) - }, func(n int, d bool) {}, 0) - log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address) -} -func UDPOutBridge() { +// func LocalUDPServer(sc *ServerChannel) { +// (*sc).ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) { -} +// }) +// } + +// func TCPOutBridge(inConn *net.Conn, userProxy bool, address string, req *HTTPRequest) { +// var outConn net.Conn +// var _outConn interface{} +// var err error +// if userProxy { +// _outConn, err = outPool.Get() +// if err == nil { +// outConn = _outConn.(net.Conn) +// } +// } else { +// outConn, err = ConnectHost(address, cfg.GetInt("tcp-timeout")) +// } +// if err != nil { +// log.Printf("connect to %s , err:%s", address, err) +// closeConn(inConn) +// return +// } +// inAddr := (*inConn).RemoteAddr().String() +// outAddr := outConn.RemoteAddr().String() +// log.Printf("%s use proxy %v", address, userProxy) + +// if req != nil { +// if req.IsHTTPS() && !userProxy { +// req.HTTPSReply() +// } else { +// outConn.Write(req.headBuf) +// } +// } + +// IoBind(*inConn, outConn, func(err error) { +// log.Printf("conn %s - %s [%s] released", inAddr, outAddr, address) +// closeConn(inConn) +// closeConn(&outConn) +// }, func(n int, d bool) {}, 0) +// log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address) +// } +// func UDPOutBridge() { + +// } // func DoUDP() { // if cfg.GetBool("local-udp") { @@ -583,26 +606,3 @@ func UDPOutBridge() { // }, func(n int, d bool) {}, 0) // log.Printf("conn %s - %s [%s] connected", inAddr, outAddr, address) // } -func closeConn(conn *net.Conn) { - if *conn != nil { - (*conn).SetDeadline(time.Now().Add(time.Millisecond)) - (*conn).Close() - } -} -func keygen() (err error) { - cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048") - out, err := cmd.CombinedOutput() - if err != nil { - log.Printf("err:%s", err) - return - } - fmt.Println(string(out)) - cmd = exec.Command("sh", "-c", `openssl req -new -key proxy.key -x509 -days 3650 -out proxy.crt -subj /C=CN/ST=BJ/O="Localhost Ltd"/CN=proxy`) - out, err = cmd.CombinedOutput() - if err != nil { - log.Printf("err:%s", err) - return - } - fmt.Println(string(out)) - return -} diff --git a/services/args.go b/services/args.go new file mode 100644 index 0000000..d33b119 --- /dev/null +++ b/services/args.go @@ -0,0 +1,46 @@ +package services + +// tcp := app.Command("tcp", "proxy on tcp mode") +// t := tcp.Flag("tcp-timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("2000").Int() + +const ( + TYPE_TCP = "tcp" + TYPE_UDP = "udp" + TYPE_HTTP = "http" + TYPE_TLS = "tls" +) + +type Args struct { + Local *string + Parent *string + CertBytes []byte + KeyBytes []byte + PoolSize *int + CheckParentInterval *int +} +type TCPArgs struct { + Args + Timeout *int + ParentType *string +} +type TLSArgs struct { + Args + Timeout *int + ParentType *string +} +type HTTPArgs struct { + Args + Always *bool + HTTPTimeout *int + Timeout *int + Interval *int + Blocked *string + Direct *string + AuthFile *string + Auth *[]string + ParentType *string + LocalType *string +} +type UDPArgs struct { + Args +} diff --git a/services/http.go b/services/http.go new file mode 100644 index 0000000..332b9b3 --- /dev/null +++ b/services/http.go @@ -0,0 +1,208 @@ +package services + +import ( + "fmt" + "log" + "net" + "proxy/utils" + "runtime/debug" + "strconv" +) + +type HTTP struct { + outPool utils.OutPool + cfg HTTPArgs + checker utils.Checker + basicAuth utils.BasicAuth +} + +func NewHTTP() Service { + return &HTTP{ + outPool: utils.OutPool{}, + cfg: HTTPArgs{}, + checker: utils.Checker{}, + basicAuth: utils.BasicAuth{}, + } +} +func (s *HTTP) InitService() { + s.InitBasicAuth() + s.checker = utils.NewChecker(*s.cfg.HTTPTimeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct) +} +func (s *HTTP) StopService() { + if s.outPool.Pool != nil { + s.outPool.Pool.ReleaseAll() + } +} +func (s *HTTP) Start(args interface{}) (err error) { + s.cfg = args.(HTTPArgs) + if *s.cfg.Parent != "" { + log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent) + s.InitOutConnPool() + } + + s.InitService() + + host, port, _ := net.SplitHostPort(*s.cfg.Local) + p, _ := strconv.Atoi(port) + sc := utils.NewServerChannel(host, p) + if *s.cfg.LocalType == TYPE_TCP { + err = sc.ListenTCP(s.callback) + } else { + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.callback) + } + if err != nil { + return + } + log.Printf("%s http(s) proxy on %s", *s.cfg.LocalType, (*sc.Listener).Addr()) + return +} + +func (s *HTTP) Clean() { + s.StopService() +} +func (s *HTTP) callback(inConn net.Conn) { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("http(s) conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) + } + }() + req, err := utils.NewHTTPRequest(&inConn, 4096, s.IsBasicAuth(), &s.basicAuth) + if err != nil { + log.Printf("decoder error , form %s, ERR:%s", err, inConn.RemoteAddr()) + utils.CloseConn(&inConn) + return + } + address := req.Host + useProxy := true + if *s.cfg.Parent == "" { + useProxy = false + } else if *s.cfg.Always { + useProxy = true + } else { + useProxy, _, _ = s.checker.IsBlocked(req.Host) + } + log.Printf("use proxy : %v, %s", useProxy, address) + //os.Exit(0) + err = s.OutToTCP(useProxy, address, &inConn, &req) + if err != nil { + if *s.cfg.Parent == "" { + log.Printf("connect to %s fail, ERR:%s", address, err) + } else { + log.Printf("connect to %s parent %s fail", *s.cfg.ParentType, *s.cfg.Parent) + } + utils.CloseConn(&inConn) + } + }() +} +func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *utils.HTTPRequest) (err error) { + inAddr := (*inConn).RemoteAddr().String() + inLocalAddr := (*inConn).LocalAddr().String() + //防止死循环 + if s.IsDeadLoop(inLocalAddr, req.Host) { + utils.CloseConn(inConn) + err = fmt.Errorf("dead loop detected , %s", req.Host) + return + } + var outConn net.Conn + var _outConn interface{} + if useProxy { + _outConn, err = s.outPool.Pool.Get() + if err == nil { + outConn = _outConn.(net.Conn) + } + } else { + outConn, err = utils.ConnectHost(address, *s.cfg.Timeout) + } + if err != nil { + log.Printf("connect to %s , err:%s", *s.cfg.Parent, err) + utils.CloseConn(inConn) + return + } + + outAddr := outConn.RemoteAddr().String() + outLocalAddr := outConn.LocalAddr().String() + + if req.IsHTTPS() && !useProxy { + req.HTTPSReply() + } else { + outConn.Write(req.HeadBuf) + } + utils.IoBind((*inConn), outConn, func(err error) { + log.Printf("conn %s - %s - %s -%s released [%s]", inAddr, inLocalAddr, outLocalAddr, outAddr, req.Host) + utils.CloseConn(inConn) + utils.CloseConn(&outConn) + }, func(n int, d bool) {}, 0) + log.Printf("conn %s - %s - %s - %s connected [%s]", inAddr, inLocalAddr, outLocalAddr, outAddr, req.Host) + return +} +func (s *HTTP) OutToUDP(inConn *net.Conn) (err error) { + return +} +func (s *HTTP) InitOutConnPool() { + if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP { + //dur int, isTLS bool, certBytes, keyBytes []byte, + //parent string, timeout int, InitialCap int, MaxCap int + s.outPool = utils.NewOutPool( + *s.cfg.CheckParentInterval, + *s.cfg.ParentType == TYPE_TLS, + s.cfg.CertBytes, s.cfg.KeyBytes, + *s.cfg.Parent, + *s.cfg.Timeout, + *s.cfg.PoolSize, + *s.cfg.PoolSize*2, + ) + } +} +func (s *HTTP) InitBasicAuth() (err error) { + s.basicAuth = utils.NewBasicAuth() + if *s.cfg.AuthFile != "" { + var n = 0 + n, err = s.basicAuth.AddFromFile(*s.cfg.AuthFile) + if err != nil { + err = fmt.Errorf("auth-file ERR:%s", err) + return + } + log.Printf("auth data added from file %d , total:%d", n, s.basicAuth.Total()) + } + if len(*s.cfg.Auth) > 0 { + n := s.basicAuth.Add(*s.cfg.Auth) + log.Printf("auth data added %d, total:%d", n, s.basicAuth.Total()) + } + return +} +func (s *HTTP) IsBasicAuth() bool { + return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 +} +func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool { + inIP, inPort, err := net.SplitHostPort(inLocalAddr) + if err != nil { + return false + } + outDomain, outPort, err := net.SplitHostPort(host) + if err != nil { + return false + } + if inPort == outPort { + var outIPs []net.IP + outIPs, err = net.LookupIP(outDomain) + if err == nil { + for _, ip := range outIPs { + if ip.String() == inIP { + return true + } + } + } + interfaceIPs, err := utils.GetAllInterfaceAddr() + if err == nil { + for _, localIP := range interfaceIPs { + for _, outIP := range outIPs { + if localIP.Equal(outIP) { + return true + } + } + } + } + } + return false +} diff --git a/services/service.go b/services/service.go new file mode 100644 index 0000000..575763a --- /dev/null +++ b/services/service.go @@ -0,0 +1,48 @@ +package services + +import ( + "fmt" + "log" + "runtime/debug" +) + +type Service interface { + Start(args interface{}) (err error) + Clean() +} +type ServiceItem struct { + S Service + Args interface{} + Name string +} + +var servicesMap = map[string]ServiceItem{} + +func Regist(name string, s Service, args interface{}) { + servicesMap[name] = ServiceItem{ + S: s, + Args: args, + Name: name, + } +} +func Run(name string) (service ServiceItem, err error) { + service, ok := servicesMap[name] + if ok { + go func() { + defer func() { + err := recover() + if err != nil { + log.Fatalf("%s servcie crashed, ERR: %s\ntrace:%s", name, err, string(debug.Stack())) + } + }() + err := service.S.Start(service.Args) + if err != nil { + log.Fatalf("%s servcie fail, ERR: %s", name, err) + } + }() + } + if !ok { + err = fmt.Errorf("service %s not found", name) + } + return +} diff --git a/services/tcp.go b/services/tcp.go new file mode 100644 index 0000000..0c27a0e --- /dev/null +++ b/services/tcp.go @@ -0,0 +1,121 @@ +package services + +import ( + "fmt" + "log" + "net" + "proxy/utils" + "runtime/debug" + + "strconv" +) + +type TCP struct { + outPool utils.OutPool + cfg TCPArgs +} + +func NewTCP() Service { + return &TCP{ + outPool: utils.OutPool{}, + cfg: TCPArgs{}, + } +} +func (s *TCP) InitService() { + s.InitOutConnPool() +} +func (s *TCP) StopService() { + if s.outPool.Pool != nil { + s.outPool.Pool.ReleaseAll() + } +} +func (s *TCP) Start(args interface{}) (err error) { + s.cfg = args.(TCPArgs) + if *s.cfg.Parent != "" { + log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent) + } else { + log.Fatalf("parent required for tcp", *s.cfg.Local) + } + + s.InitService() + + host, port, _ := net.SplitHostPort(*s.cfg.Local) + p, _ := strconv.Atoi(port) + sc := utils.NewServerChannel(host, p) + err = sc.ListenTCP(func(inConn net.Conn) { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("tcp conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) + } + }() + var err error + switch *s.cfg.ParentType { + case TYPE_TCP: + fallthrough + case TYPE_TLS: + err = s.OutToTCP(&inConn) + case TYPE_UDP: + err = s.OutToUDP(&inConn) + default: + err = fmt.Errorf("unkown parent type %s", *s.cfg.ParentType) + } + if err != nil { + log.Printf("connect to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err) + utils.CloseConn(&inConn) + } + }() + }) + if err != nil { + return + } + log.Printf("tcp proxy on %s", (*sc.Listener).Addr()) + return +} + +func (s *TCP) Clean() { + s.StopService() +} + +func (s *TCP) OutToTCP(inConn *net.Conn) (err error) { + var outConn net.Conn + var _outConn interface{} + _outConn, err = s.outPool.Pool.Get() + if err == nil { + outConn = _outConn.(net.Conn) + } + if err != nil { + log.Printf("connect to %s , err:%s", *s.cfg.Parent, err) + utils.CloseConn(inConn) + return + } + inAddr := (*inConn).RemoteAddr().String() + inLocalAddr := (*inConn).LocalAddr().String() + outAddr := outConn.RemoteAddr().String() + outLocalAddr := outConn.LocalAddr().String() + utils.IoBind((*inConn), outConn, func(err error) { + log.Printf("conn %s - %s - %s -%s released", inAddr, inLocalAddr, outLocalAddr, outAddr) + utils.CloseConn(inConn) + utils.CloseConn(&outConn) + }, func(n int, d bool) {}, 0) + log.Printf("conn %s - %s - %s -%s connected", inAddr, inLocalAddr, outLocalAddr, outAddr) + return +} +func (s *TCP) OutToUDP(inConn *net.Conn) (err error) { + return +} +func (s *TCP) InitOutConnPool() { + if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP { + //dur int, isTLS bool, certBytes, keyBytes []byte, + //parent string, timeout int, InitialCap int, MaxCap int + s.outPool = utils.NewOutPool( + *s.cfg.CheckParentInterval, + *s.cfg.ParentType == TYPE_TLS, + s.cfg.CertBytes, s.cfg.KeyBytes, + *s.cfg.Parent, + *s.cfg.Timeout, + *s.cfg.PoolSize, + *s.cfg.PoolSize*2, + ) + } +} diff --git a/services/tls.go b/services/tls.go new file mode 100644 index 0000000..2c09dc0 --- /dev/null +++ b/services/tls.go @@ -0,0 +1,117 @@ +package services + +import ( + "fmt" + "log" + "net" + "proxy/utils" + "runtime/debug" + "strconv" +) + +type TLS struct { + outPool utils.OutPool + cfg TLSArgs +} + +func NewTLS() Service { + return &TLS{} +} +func (s *TLS) InitService() { + s.InitOutConnPool() +} +func (s *TLS) StopService() { + if s.outPool.Pool != nil { + s.outPool.Pool.ReleaseAll() + } +} +func (s *TLS) Start(args interface{}) (err error) { + s.cfg = args.(TLSArgs) + if *s.cfg.Parent != "" { + log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent) + } else { + log.Fatalf("parent required for tls") + } + + s.InitService() + + host, port, _ := net.SplitHostPort(*s.cfg.Local) + p, _ := strconv.Atoi(port) + sc := utils.NewServerChannel(host, p) + err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, func(inConn net.Conn) { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("tls conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack())) + } + }() + var err error + switch *s.cfg.ParentType { + case TYPE_TCP: + fallthrough + case TYPE_TLS: + err = s.OutToTCP(&inConn) + case TYPE_UDP: + err = s.OutToUDP(&inConn) + default: + err = fmt.Errorf("unkown parent type %s", *s.cfg.ParentType) + } + if err != nil { + log.Printf("connect to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err) + utils.CloseConn(&inConn) + } + }() + }) + if err != nil { + return + } + log.Printf("tls proxy on %s", (*sc.Listener).Addr()) + return +} + +func (s *TLS) Clean() { + s.StopService() +} + +func (s *TLS) OutToTCP(inConn *net.Conn) (err error) { + var outConn net.Conn + var _outConn interface{} + _outConn, err = s.outPool.Pool.Get() + if err == nil { + outConn = _outConn.(net.Conn) + } + if err != nil { + log.Printf("connect to %s , err:%s", *s.cfg.Parent, err) + utils.CloseConn(inConn) + return + } + inAddr := (*inConn).RemoteAddr().String() + inLocalAddr := (*inConn).LocalAddr().String() + outAddr := outConn.RemoteAddr().String() + outLocalAddr := outConn.LocalAddr().String() + utils.IoBind((*inConn), outConn, func(err error) { + log.Printf("conn %s - %s - %s -%s released", inAddr, inLocalAddr, outLocalAddr, outAddr) + utils.CloseConn(inConn) + utils.CloseConn(&outConn) + }, func(n int, d bool) {}, 0) + log.Printf("conn %s - %s - %s -%s connected", inAddr, inLocalAddr, outLocalAddr, outAddr) + return +} +func (s *TLS) OutToUDP(inConn *net.Conn) (err error) { + return +} +func (s *TLS) InitOutConnPool() { + if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP { + //dur int, isTLS bool, certBytes, keyBytes []byte, + //parent string, timeout int, InitialCap int, MaxCap int + s.outPool = utils.NewOutPool( + *s.cfg.CheckParentInterval, + *s.cfg.ParentType == TYPE_TLS, + s.cfg.CertBytes, s.cfg.KeyBytes, + *s.cfg.Parent, + *s.cfg.Timeout, + *s.cfg.PoolSize, + *s.cfg.PoolSize*2, + ) + } +} diff --git a/services/udp.go b/services/udp.go new file mode 100644 index 0000000..9f9aaa0 --- /dev/null +++ b/services/udp.go @@ -0,0 +1,19 @@ +package services + +import ( + "log" +) + +type UDP struct { +} + +func NewUDP() Service { + return &UDP{} +} +func (s *UDP) Start(args interface{}) (err error) { + log.Printf("called") + return +} +func (s *UDP) Clean() { + +} diff --git a/functions.go b/utils/functions.go similarity index 66% rename from functions.go rename to utils/functions.go index 7c20a6d..9eb8ed1 100755 --- a/functions.go +++ b/utils/functions.go @@ -1,4 +1,4 @@ -package main +package utils import ( "crypto/tls" @@ -10,11 +10,11 @@ import ( "net" "net/http" "os" - "os/signal" + "os/exec" + "runtime/debug" "strconv" "strings" - "syscall" "time" ) @@ -187,118 +187,73 @@ func HTTPGet(URL string, timeout int) (err error) { return } -func initOutPool(isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) { - var err error - outPool, err = NewConnPool(poolConfig{ - IsActive: func(conn interface{}) bool { return true }, - Release: func(conn interface{}) { - if conn != nil { - conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond)) - conn.(net.Conn).Close() - // log.Println("conn released") - } - }, - InitialCap: InitialCap, - MaxCap: MaxCap, - Factory: func() (conn interface{}, err error) { - conn, err = getConn(isTLS, certBytes, keyBytes, address, timeout) - return - }, - }) +func CloseConn(conn *net.Conn) { + if *conn != nil { + (*conn).SetDeadline(time.Now().Add(time.Millisecond)) + (*conn).Close() + } +} +func Keygen() (err error) { + cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048") + out, err := cmd.CombinedOutput() if err != nil { - log.Fatalf("init conn pool fail ,%s", err) - } else { - log.Printf("init conn pool success") - initPoolDeamon(isTLS, certBytes, keyBytes, address, timeout) - } -} -func getConn(isTLS bool, certBytes, keyBytes []byte, address string, timeout int) (conn interface{}, err error) { - if isTLS { - var _conn tls.Conn - _conn, err = TlsConnectHost(address, timeout, certBytes, keyBytes) - if err == nil { - conn = net.Conn(&_conn) - } - } else { - conn, err = ConnectHost(address, timeout) - } - return -} -func initPoolDeamon(isTLS bool, certBytes, keyBytes []byte, address string, timeout int) { - go func() { - dur := cfg.GetInt("check-proxy-interval") - if dur <= 0 { - return - } - log.Printf("pool deamon started") - for { - time.Sleep(time.Second * time.Duration(dur)) - conn, err := getConn(isTLS, certBytes, keyBytes, address, timeout) - if err != nil { - log.Printf("pool deamon err %s , release pool", err) - outPool.ReleaseAll() - } else { - conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond)) - conn.(net.Conn).Close() - } - } - }() -} -func IsBasicAuth() bool { - return cfg.GetString("auth-file") != "" || len(cfg.GetStringSlice("auth")) > 0 -} -func InitBasicAuth() (err error) { - basicAuth = NewBasicAuth() - if cfg.GetString("auth-file") != "" { - n, err := basicAuth.AddFromFile(cfg.GetString("auth-file")) - if err != nil { - return fmt.Errorf("auth-file ERR:%s", err) - } - log.Printf("auth data added from file %d , total:%d", n, basicAuth.Total()) - } - if len(cfg.GetStringSlice("auth")) > 0 { - n := basicAuth.Add(cfg.GetStringSlice("auth")) - log.Printf("auth data added %d, total:%d", n, basicAuth.Total()) - } - return -} -func clean() { - //block main() - signalChan := make(chan os.Signal, 1) - cleanupDone := make(chan bool) - signal.Notify(signalChan, - os.Interrupt, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT) - go func() { - for _ = range signalChan { - if outPool != nil { - fmt.Println("\nReceived an interrupt, stopping services...") - outPool.ReleaseAll() - //time.Sleep(time.Second * 10) - // fmt.Println("done") - } - cleanupDone <- true - } - }() - <-cleanupDone -} -func HTTPProxyDecoder(inConn *net.Conn) (useProxy bool, request *HTTPRequest, err error) { - var req HTTPRequest - req, err = NewHTTPRequest(inConn, 4096) - if err != nil { - //log.Printf("NewHTTPRequest ERR:%s", err) + log.Printf("err:%s", err) return } - useProxy = false - if checker.data != nil { - useProxy, _, _ = checker.IsBlocked(req.Host) + fmt.Println(string(out)) + cmd = exec.Command("sh", "-c", `openssl req -new -key proxy.key -x509 -days 3650 -out proxy.crt -subj /C=CN/ST=BJ/O="Localhost Ltd"/CN=proxy`) + out, err = cmd.CombinedOutput() + if err != nil { + log.Printf("err:%s", err) + return } - request = &req + fmt.Println(string(out)) return } +func GetAllInterfaceAddr() ([]net.IP, error) { + + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + addresses := []net.IP{} + for _, iface := range ifaces { + + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + // if iface.Flags&net.FlagLoopback != 0 { + // continue // loopback interface + // } + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + // if ip == nil || ip.IsLoopback() { + // continue + // } + ip = ip.To4() + if ip == nil { + continue // not an ipv4 address + } + addresses = append(addresses, ip) + } + } + if len(addresses) == 0 { + return nil, fmt.Errorf("no address Found, net.InterfaceAddrs: %v", addresses) + } + //only need first + return addresses, nil +} // type sockaddr struct { // family uint16 diff --git a/io-limiter.go b/utils/io-limiter.go similarity index 99% rename from io-limiter.go rename to utils/io-limiter.go index 37c8417..5162d15 100644 --- a/io-limiter.go +++ b/utils/io-limiter.go @@ -1,4 +1,4 @@ -package main +package utils import ( "context" diff --git a/map.go b/utils/map.go similarity index 99% rename from map.go rename to utils/map.go index 70277e5..8ec82cf 100644 --- a/map.go +++ b/utils/map.go @@ -1,4 +1,4 @@ -package main +package utils import ( "encoding/json" diff --git a/pool.go b/utils/pool.go similarity index 99% rename from pool.go rename to utils/pool.go index 21394b9..ae30f6f 100755 --- a/pool.go +++ b/utils/pool.go @@ -1,4 +1,4 @@ -package main +package utils import ( "log" diff --git a/serve-channel.go b/utils/serve-channel.go similarity index 96% rename from serve-channel.go rename to utils/serve-channel.go index dd85f71..50753ba 100644 --- a/serve-channel.go +++ b/utils/serve-channel.go @@ -1,4 +1,4 @@ -package main +package utils import ( "fmt" @@ -60,7 +60,8 @@ func (sc *ServerChannel) ListenTls(certBytes, keyBytes []byte, fn func(conn net. } func (sc *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) { - l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", sc.ip, sc.port)) + var l net.Listener + l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sc.ip, sc.port)) if err == nil { sc.Listener = &l go func() { diff --git a/structs.go b/utils/structs.go similarity index 75% rename from structs.go rename to utils/structs.go index dd298c4..101a10f 100644 --- a/structs.go +++ b/utils/structs.go @@ -1,7 +1,8 @@ -package main +package utils import ( "bytes" + "crypto/tls" "encoding/base64" "fmt" "io" @@ -223,15 +224,17 @@ func (ba *BasicAuth) Total() (n int) { } type HTTPRequest struct { - headBuf []byte - conn *net.Conn - Host string - Method string - URL string - hostOrURL string + HeadBuf []byte + conn *net.Conn + Host string + Method string + URL string + hostOrURL string + isBasicAuth bool + basicAuth *BasicAuth } -func NewHTTPRequest(inConn *net.Conn, bufSize int) (req HTTPRequest, err error) { +func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth) (req HTTPRequest, err error) { buf := make([]byte, bufSize) len := 0 req = HTTPRequest{ @@ -242,23 +245,25 @@ func NewHTTPRequest(inConn *net.Conn, bufSize int) (req HTTPRequest, err error) if err != io.EOF { err = fmt.Errorf("http decoder read err:%s", err) } - closeConn(inConn) + CloseConn(inConn) return } - req.headBuf = buf[:len] - index := bytes.IndexByte(req.headBuf, '\n') + req.HeadBuf = buf[:len] + index := bytes.IndexByte(req.HeadBuf, '\n') if index == -1 { - err = fmt.Errorf("http decoder data line err:%s", string(req.headBuf)[:50]) - closeConn(inConn) + err = fmt.Errorf("http decoder data line err:%s", string(req.HeadBuf)[:50]) + CloseConn(inConn) return } - fmt.Sscanf(string(req.headBuf[:index]), "%s%s", &req.Method, &req.hostOrURL) + fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL) if req.Method == "" || req.hostOrURL == "" { - err = fmt.Errorf("http decoder data err:%s", string(req.headBuf)[:50]) - closeConn(inConn) + err = fmt.Errorf("http decoder data err:%s", string(req.HeadBuf)[:50]) + CloseConn(inConn) return } req.Method = strings.ToUpper(req.Method) + req.isBasicAuth = isBasicAuth + req.basicAuth = basicAuth log.Printf("%s:%s", req.Method, req.hostOrURL) if req.IsHTTPS() { @@ -269,7 +274,7 @@ func NewHTTPRequest(inConn *net.Conn, bufSize int) (req HTTPRequest, err error) return } func (req *HTTPRequest) HTTP() (err error) { - if IsBasicAuth() { + if req.isBasicAuth { err = req.BasicAuth() if err != nil { return @@ -303,27 +308,27 @@ func (req *HTTPRequest) BasicAuth() (err error) { authorization, err := req.getHeader("Authorization") if err != nil { fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized") - closeConn(req.conn) + CloseConn(req.conn) return } //log.Printf("Authorization:%s", authorization) basic := strings.Fields(authorization) if len(basic) != 2 { err = fmt.Errorf("authorization data error,ERR:%s", authorization) - closeConn(req.conn) + CloseConn(req.conn) return } user, err := base64.StdEncoding.DecodeString(basic[1]) if err != nil { err = fmt.Errorf("authorization data parse error,ERR:%s", err) - closeConn(req.conn) + CloseConn(req.conn) return } - authOk := basicAuth.Check(string(user)) + authOk := (*req.basicAuth).Check(string(user)) //log.Printf("auth %s,%v", string(user), authOk) if !authOk { fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized") - closeConn(req.conn) + CloseConn(req.conn) err = fmt.Errorf("basic auth fail") return } @@ -342,7 +347,7 @@ func (req *HTTPRequest) getHTTPURL() (URL string, err error) { } func (req *HTTPRequest) getHeader(key string) (val string, err error) { key = strings.ToUpper(key) - lines := strings.Split(string(req.headBuf), "\r\n") + lines := strings.Split(string(req.HeadBuf), "\r\n") for _, line := range lines { line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2) if len(line) == 2 { @@ -357,6 +362,7 @@ func (req *HTTPRequest) getHeader(key string) (val string, err error) { err = fmt.Errorf("can not find HOST header") return } + func (req *HTTPRequest) addPortIfNot() (newHost string) { //newHost = req.Host port := "80" @@ -371,14 +377,83 @@ func (req *HTTPRequest) addPortIfNot() (newHost string) { return } -// func (req *HTTPRequest) fixHost(host string) string { -// if !strings.HasPrefix(host, "[") && len(strings.Split(host, ":")) > 2 { -// if strings.HasSuffix(host, ":80") { -// return fmt.Sprintf("[%s]:80", host[:strings.LastIndex(host, ":")]) -// } -// if strings.HasSuffix(host, ":443") { -// return fmt.Sprintf("[%s]:443", host[:strings.LastIndex(host, ":")]) -// } -// } -// return host -// } +type OutPool struct { + Pool ConnPool + dur int + isTLS bool + certBytes []byte + keyBytes []byte + address string + timeout int +} + +func NewOutPool(dur int, isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) { + op = OutPool{ + dur: dur, + isTLS: isTLS, + certBytes: certBytes, + keyBytes: keyBytes, + address: address, + timeout: timeout, + } + var err error + op.Pool, err = NewConnPool(poolConfig{ + IsActive: func(conn interface{}) bool { return true }, + Release: func(conn interface{}) { + if conn != nil { + conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond)) + conn.(net.Conn).Close() + // log.Println("conn released") + } + }, + InitialCap: InitialCap, + MaxCap: MaxCap, + Factory: func() (conn interface{}, err error) { + conn, err = op.getConn() + return + }, + }) + if err != nil { + log.Fatalf("init conn pool fail ,%s", err) + } else { + if InitialCap > 0 { + log.Printf("init conn pool success") + op.initPoolDeamon() + } else { + log.Printf("conn pool closed") + } + } + return +} +func (op *OutPool) getConn() (conn interface{}, err error) { + if op.isTLS { + var _conn tls.Conn + _conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes) + if err == nil { + conn = net.Conn(&_conn) + } + } else { + conn, err = ConnectHost(op.address, op.timeout) + } + return +} + +func (op *OutPool) initPoolDeamon() { + go func() { + if op.dur <= 0 { + return + } + log.Printf("pool deamon started") + for { + time.Sleep(time.Second * time.Duration(op.dur)) + conn, err := op.getConn() + if err != nil { + log.Printf("pool deamon err %s , release pool", err) + op.Pool.ReleaseAll() + } else { + conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond)) + conn.(net.Conn).Close() + } + } + }() +}