diff --git a/CHANGELOG b/CHANGELOG
index 71c2141..78104f4 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,4 +1,29 @@
proxy更新日志
+
+v6.0 企业版开源啦
+本次更新主要是把企业版开源,把企业版代码合并到现在的开源goproxy当中,继续遵循GPLv3,免费开源,
+之所以直接跳过5.x,用6.0版本号是为了与现有开源版本做一个明显的区分,下面功能主要来自企业版.
+企业版代码结构更合理,核心与开源版本有很大区别,与此同时企业版有一个core开发库,基于此库可以
+几行代码实现自己高度定制化的各种网络安全传输服务器和客户端和代理服务器与客户端.与此同时企
+业版独创了TCPS协议,处于应用层和TCP层之间,可以为应用提供透明化的安全传输功能,另外还对dst协
+议进行了一些改造,集成到goproxy中,实现了tcp over udp功能,那么除了kcp之外现在还可以选择dst
+作为底层的tcp over udp的传输.下一步加入插件机制,定制功能可以使用插件方式开发了,热插拔,
+不需要修改goproxy二进制,可以插件so或者dylib注入.
+
+1.预编译的二进制增加了armv8支持.
+2.预编译的mipsle和mips二进制增加了softfloat支持.
+3.优化连接HTTP(s)上级代理的CONNECT指令,附带更多的信息.
+4.重构了内网穿透的UDP功能,性能大幅度提升,可以愉快的与异地基友玩依赖UDP的局域网游戏了.
+5.重构了UDP端口映射,性能大幅度提升.
+6.HTTP(S)\SOCKS5\SPS代理支持上级负载均衡,可以同时指定多个上级.
+7.SPS支持HTTP(S)\SOCKS5\SS协议相互转换.
+8.HTTP(S)\SOCKS5\SPS代理支持限速.
+9.HTTP(S)\SOCKS5代理支持指定出口IP.
+10.SOCKS5代理支持级联认证.
+11.修复了tclient可能意外退出的bug.
+12.优化了错误捕获,防止意外crash.
+13.优化了停止服务,释放内存.
+
v5.4
1.优化了获取本地IP信息导致CPU过高的问题.
2.所有服务都增加了--nolog参数,可以关闭日志输出,节省CPU.
diff --git a/README_ZH.md b/README_ZH.md
index 80668a5..c370d1f 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -1,5 +1,5 @@
-Proxy是golang实现的高性能http,https,websocket,tcp,udp,socks5代理服务器,支持正向代理、反向代理、透明代理、内网穿透、TCP/UDP端口映射、SSH中转、TLS加密传输、协议转换、防污染DNS代理。
+Proxy是golang实现的高性能http,https,websocket,tcp,udp,socks5,ss代理服务器,支持正向代理、反向代理、透明代理、内网穿透、TCP/UDP端口映射、SSH中转、TLS加密传输、协议转换、防污染DNS代理。
[点击下载](https://github.com/snail007/goproxy/releases) 官方QQ交流群:189618940
@@ -33,10 +33,16 @@ Proxy是golang实现的高性能http,https,websocket,tcp,udp,socks5代理服务
- 集成外部API,HTTP(S),SOCKS5代理认证功能可以与外部HTTP API集成,可以方便的通过外部系统控制代理用户.
- 反向代理,支持直接把域名解析到proxy监听的ip,然后proxy就会帮你代理访问需要访问的HTTP(S)网站.
- 透明HTTP(S)代理,配合iptables,在网关直接把出去的80,443方向的流量转发到proxy,就能实现无感知的智能路由器代理.
-- 协议转换,可以把已经存在的HTTP(S)或SOCKS5代理转换为一个端口同时支持HTTP(S)和SOCKS5代理,转换后的SOCKS5代理不支持UDP功能,同时支持强大的级联认证功能。
+- 协议转换,可以把已经存在的HTTP(S)或SOCKS5或SS代理转换为一个端口同时支持HTTP(S)和SOCKS5和SS代理,转换后的SOCKS5和SS代理如果上级是SOCKS5代理,那么支持UDP功能,同时支持强大的级联认证功能。
- 自定义底层加密传输,http(s)\sps\socks代理在tcp之上可以通过tls标准加密以及kcp协议加密tcp数据,除此之外还支持在tls和kcp之后进行自定义加密,也就是说自定义加密和tls|kcp是可以联合使用的,内部采用AES256加密,使用的时候只需要自己定义一个密码即可。
- 底层压缩高效传输,http(s)\sps\socks代理在tcp之上可以通过自定义加密和tls标准加密以及kcp协议加密tcp数据,在加密之后还可以对数据进行压缩,也就是说压缩功能和自定义加密和tls|kcp是可以联合使用的。
- 安全的DNS代理,可以通过本地的proxy提供的DNS代理服务器与上级代理加密通讯实现安全防污染的DNS查询。
+- 负载均衡,高可用,HTTP(S)\SOCKS5\SPS代理支持上级负载均衡和高可用,多个上级重复-P参数即可.
+- 指定出口IP,HTTP(S)\SOCKS5\SPS代理支持客户端用入口IP连接过来的,就用入口IP作为出口IP访问目标网站的功能。如果入口IP是内网IP,出口IP不会使用入口IP
+- 支持限速,HTTP(S)\SOCKS5\SPS代理支持限速.
+- SOCKS5代理支持级联认证.
+- 证书参数使用base64数据,默认情况下-C,-K参数是crt证书和key文件的路径,如果是base64://开头,那么就认为后面的数据是base64编码的,会解码后使用.
+
### Why need these?
- 当由于某某原因,我们不能访问我们在其它地方的服务,我们可以通过多个相连的proxy节点建立起一个安全的隧道访问我们的服务.
@@ -70,8 +76,9 @@ Proxy是golang实现的高性能http,https,websocket,tcp,udp,socks5代理服务
- [安全建议](#安全建议)
### 手册目录
+- [负载均衡和高可用](#负载均衡和高可用)
- [1. HTTP代理](#1http代理)
- - [1.1 普通HTTP代理](#11普通http代理)
+ - [1.1 普通HTTP代理](#11普通一级http代理)
- [1.2 普通二级HTTP代理](#12普通二级http代理)
- [1.3 HTTP二级代理(加密)](#13http二级代理加密)
- [1.4 HTTP三级代理(加密)](#14http三级代理加密)
@@ -86,7 +93,11 @@ Proxy是golang实现的高性能http,https,websocket,tcp,udp,socks5代理服务
- [1.11 自定义DNS](#111-自定义dns)
- [1.12 自定义加密](#112-自定义加密)
- [1.13 压缩传输](#113-压缩传输)
- - [1.14 查看帮助](#114-查看帮助)
+ - [1.14 负载均衡](#114-负载均衡)
+ - [1.15 限速](#115-限速)
+ - [1.16 指定出口IP](#116-指定出口ip)
+ - [1.17 证书参数使用base64数据](#117-证书参数使用base64数据)
+ - [1.18 查看帮助](#118-查看帮助)
- [2. TCP代理(端口映射)](#2tcp代理)
- [2.1 普通一级TCP代理](#21普通一级tcp代理)
- [2.2 普通二级TCP代理](#22普通二级tcp代理)
@@ -126,18 +137,27 @@ Proxy是golang实现的高性能http,https,websocket,tcp,udp,socks5代理服务
- [5.9 自定义DNS](#59自定义dns)
- [5.10 自定义加密](#510-自定义加密)
- [5.11 压缩传输](#511-压缩传输)
- - [5.12 查看帮助](#512查看帮助)
+ - [5.12 负载均衡](#512-负载均衡)
+ - [5.13 限速](#513-限速)
+ - [5.14 指定出口IP](#514-指定出口ip)
+ - [5.15 级联认证](#515-级联认证)
+ - [5.16 证书参数使用base64数据](#516-证书参数使用base64数据)
+ - [5.17 查看帮助](#517查看帮助)
- [6. 代理协议转换](#6代理协议转换)
- [6.1 功能介绍](#61-功能介绍)
- - [6.2 HTTP(S)转HTTP(S)+SOCKS5](#62-https转httpssocks5)
- - [6.3 SOCKS5转HTTP(S)+SOCKS5](#63-socks5转httpssocks5)
- - [6.4 链式连接](#64-链式连接)
- - [6.5 监听多个端口](#65-监听多个端口)
- - [6.6 认证功能](#66-认证功能)
- - [6.7 自定义加密](#67-自定义加密)
- - [6.8 压缩传输](#68-压缩传输)
- - [6.9 禁用协议](#69-禁用协议)
- - [6.10 查看帮助](#610-查看帮助)
+ - [6.2 HTTP(S)转HTTP(S)+SOCKS5+SS](#62-https转httpssocks5ss)
+ - [6.3 SOCKS5转HTTP(S)+SOCKS5+SS](#63-socks5转httpssocks5ss)
+ - [6.4 SS转HTTP(S)+SOCKS5+SS](#64-ss转httpssocks5ss)
+ - [6.5 链式连接](#65-链式连接)
+ - [6.6 监听多个端口](#66-监听多个端口)
+ - [6.7 认证功能](#67-认证功能)
+ - [6.8 自定义加密](#68-自定义加密)
+ - [6.9 压缩传输](#69-压缩传输)
+ - [6.10 禁用协议](#610-禁用协议)
+ - [6.11 限速](#611-限速)
+ - [6.12 指定出口IP](#612-指定出口ip)
+ - [6.13 证书参数使用base64数据](#613-证书参数使用base64数据)
+ - [6.14 查看帮助](#614-查看帮助)
- [7. KCP配置](#7kcp配置)
- [7.1 配置介绍](#71-配置介绍)
- [7.2 详细配置](#72-详细配置)
@@ -152,7 +172,7 @@ Proxy是golang实现的高性能http,https,websocket,tcp,udp,socks5代理服务
```shell
curl -L https://raw.githubusercontent.com/snail007/goproxy/master/install_auto.sh | bash
```
-安装完成,配置目录是/etc/proxy,更详细的使用方法参考下面的进一步了解.
+安装完成,配置目录是/etc/proxy,更详细的使用方法请参考上面的手册目录,进一步了解你想要使用的功能.
如果安装失败或者你的vps不是linux64位系统,请按照下面的半自动步骤安装:
#### 手动安装
@@ -161,7 +181,7 @@ curl -L https://raw.githubusercontent.com/snail007/goproxy/master/install_auto.s
下载地址:https://github.com/snail007/goproxy/releases
```shell
cd /root/proxy/
-wget https://github.com/snail007/goproxy/releases/download/v5.4/proxy-linux-amd64.tar.gz
+wget https://github.com/snail007/goproxy/releases/download/v6.0/proxy-linux-amd64.tar.gz
```
#### **2.下载自动安装脚本**
```shell
@@ -260,6 +280,32 @@ proxy会fork子进程,然后监控子进程,如果子进程异常退出,5秒后
假设你的vps外网ip是23.23.23.23,下面命令通过-g参数设置23.23.23.23
`./proxy http -g "23.23.23.23"`
+### **负载均衡和高可用**
+
+HTTP(S)\SOCKS5\SPS代理支持上级负载均衡和高可用,多个上级重复-P参数即可.
+
+负载均衡策略支持5种,可以通过`--lb-method`参数指定:
+
+roundrobin 轮流使用
+
+leastconn 使用最小连接数的
+
+leasttime 使用连接时间最小的
+
+hash 使用根据客户端地址计算出一个固定上级
+
+weight 根据每个上级的权重和连接数情况,选择出一个上级
+
+提示:
+
+负载均衡检查时间间隔可以通过`--lb-retrytime`设置,单位毫秒
+
+负载均衡连接超时时间可以通过`--lb-timeout`设置,单位毫秒
+
+如果负载均衡策略是权重(weight),-P格式为:2.2.2.2:3880@1,1就是权重,大于0的整数.
+
+如果负载均衡策略是hash,默认是根据客户端地址选择上级,可以通过开关`--lb-hashtarget`使用访问的目标地址选择上级.
+
### **1.HTTP代理**
#### **1.1.普通一级HTTP代理**

@@ -461,7 +507,43 @@ proxy的http(s)代理在tcp之上可以通过tls标准加密以及kcp协议加
`proxy http -T tcp -P 3.3.3.3:8888 -M -t tcp -p :8080`
这样通过本地代理8080访问网站的时候就是通过与上级压缩传输访问目标网站.
-#### **1.14 查看帮助**
+### **1.14 负载均衡**
+
+HTTP(S)代理支持上级负载均衡,多个上级重复-P参数即可.
+
+`proxy http --lb-method=hash -T tcp -P 1.1.1.1:33080 -P 2.1.1.1:33080 -P 3.1.1.1:33080`
+
+#### **1.14.1 设置重试间隔和超时时间**
+
+`proxy http --lb-method=leastconn --lb-retrytime 300 --lb-timeout 300 -T tcp -P 1.1.1.1:33080 -P 2.1.1.1:33080 -P 3.1.1.1:33080 -t tcp -p :33080`
+
+#### **1.14.2 设置权重**
+
+`proxy http --lb-method=weight -T tcp -P 1.1.1.1:33080@1 -P 2.1.1.1:33080@2 -P 3.1.1.1:33080@1 -t tcp -p :33080`
+
+#### **1.14.3 使用目标地址选择上级**
+
+`proxy http --lb-hashtarget --lb-method=leasttime -T tcp -P 1.1.1.1:33080 -P 2.1.1.1:33080 -P 3.1.1.1:33080 -t tcp -p :33080`
+
+### **1.15 限速**
+
+限速100K,通过`-l`参数即可指定,比如:100K 1.5M . 0意味着无限制.
+
+`proxy http -t tcp -p 2.2.2.2:33080 -l 100K`
+
+### **1.16 指定出口IP**
+
+`--bind-listen`参数,就可以开启客户端用入口IP连接过来的,就用入口IP作为出口IP访问目标网站的功能。如果入口IP是内网IP,出口IP不会使用入口IP。
+
+`proxy http -t tcp -p 2.2.2.2:33080 --bind-listen`
+
+### **1.17 证书参数使用base64数据**
+
+默认情况下-C,-K参数是crt证书和key文件的路径,
+
+如果是base64://开头,那么就认为后面的数据是base64编码的,会解码后使用.
+
+#### **1.18 查看帮助**
`./proxy help http`
### **2.TCP代理**
@@ -890,41 +972,98 @@ proxy的socks代理在tcp之上可以通过自定义加密和tls标准加密以
`proxy socks -T tcp -P 3.3.3.3:8888 -M -t tcp -p :8080`
这样通过本地代理8080访问网站的时候就是通过与上级压缩传输访问目标网站.
-#### **5.12.查看帮助**
+
+#### **5.12 负载均衡**
+
+SOCKS代理支持上级负载均衡,多个上级重复-P参数即可.
+
+`proxy socks --lb-method=hash -T tcp -P 1.1.1.1:33080 -P 2.1.1.1:33080 -P 3.1.1.1:33080 -p :33080 -t tcp`
+
+#### **5.12.1 设置重试间隔和超时时间**
+
+`proxy socks --lb-method=leastconn --lb-retrytime 300 --lb-timeout 300 -T tcp -P 1.1.1.1:33080 -P 2.1.1.1:33080 -P 3.1.1.1:33080 -p :33080 -t tcp`
+
+#### **5.12.2 设置权重**
+
+`proxy socks --lb-method=weight -T tcp -P 1.1.1.1:33080@1 -P 2.1.1.1:33080@2 -P 3.1.1.1:33080@1 -p :33080 -t tcp`
+
+#### **5.12.3 使用目标地址选择上级**
+
+`proxy socks --lb-hashtarget --lb-method=leasttime -T tcp -P 1.1.1.1:33080 -P 2.1.1.1:33080 -P 3.1.1.1:33080 -p :33080 -t tcp`
+
+#### **5.13 限速**
+
+限速100K,通过`-l`参数即可指定,比如:100K 1.5M . 0意味着无限制.
+
+`proxy socks -t tcp -p 2.2.2.2:33080 -l 100K`
+
+#### **5.14 指定出口IP**
+
+`--bind-listen`参数,就可以开启客户端用入口IP连接过来的,就用入口IP作为出口IP访问目标网站的功能。如果入口IP是内网IP,出口IP不会使用入口IP。
+
+`proxy socks -t tcp -p 2.2.2.2:33080 --bind-listen`
+
+#### **5.15 级联认证**
+
+SOCKS5支持级联认证,-A可以设置上级认证信息.
+
+上级:
+
+`proxy socks -t tcp -p 2.2.2.2:33080 -a user:pass`
+
+本地:
+
+`proxy socks -T tcp -P 2.2.2.2:33080 -A user:pass -t tcp -p :33080`
+
+#### **5.16 证书参数使用base64数据**
+
+默认情况下-C,-K参数是crt证书和key文件的路径,
+
+如果是base64://开头,那么就认为后面的数据是base64编码的,会解码后使用.
+
+
+#### **5.17.查看帮助**
`./proxy help socks`
### **6.代理协议转换**
#### **6.1 功能介绍**
-代理协议转换使用的是sps子命令(socks+https的缩写),sps本身不提供代理功能,只是接受代理请求"转换并转发"给已经存在的http(s)代理或者socks5代理;sps可以把已经存在的http(s)代理或者socks5代理转换为一个端口同时支持http(s)和socks5代理,而且http(s)代理支持正向代理和反向代理(SNI),转换后的SOCKS5代理,当上级是SOCKS5时仍然支持UDP功能;另外对于已经存在的http(s)代理或者socks5代理,支持tls、tcp、kcp三种模式,支持链式连接,也就是可以多个sps结点层级连接构建加密通道。
+代理协议转换使用的是sps子命令,sps本身不提供代理功能,只是接受代理请求"转换并转发"给已经存在的http(s)代理或者socks5代理或者ss代理;sps可以把已经存在的http(s)代理或者socks5代理或ss代理转换为一个端口同时支持http(s)和socks5和ss的代理,而且http(s)代理支持正向代理和反向代理(SNI),转换后的SOCKS5代理,当上级是SOCKS5或者SS时仍然支持UDP功能;另外对于已经存在的http(s)代理或者socks5代理,支持tls、tcp、kcp三种模式,支持链式连接,也就是可以多个sps结点层级连接构建加密通道。
-#### **6.2 HTTP(S)转HTTP(S)+SOCKS5**
-假设已经存在一个普通的http(s)代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5的普通代理,转换后的本地端口为18080。
+#### **6.2 HTTP(S)转HTTP(S)+SOCKS5+SS**
+假设已经存在一个普通的http(s)代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5和ss的普通代理,转换后的本地端口为18080,ss加密方式:aes-192-cfb,ss密码:pass。
命令如下:
-`./proxy sps -S http -T tcp -P 127.0.0.1:8080 -t tcp -p :18080`
+`./proxy sps -S http -T tcp -P 127.0.0.1:8080 -t tcp -p :18080 -h aes-192-cfb -j pass`
-假设已经存在一个tls的http(s)代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5的普通代理,转换后的本地端口为18080,tls需要证书文件。
+假设已经存在一个tls的http(s)代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5和ss的普通代理,转换后的本地端口为18080,tls需要证书文件,ss加密方式:aes-192-cfb,ss密码:pass。
命令如下:
-`./proxy sps -S http -T tls -P 127.0.0.1:8080 -t tcp -p :18080 -C proxy.crt -K proxy.key`
+`./proxy sps -S http -T tls -P 127.0.0.1:8080 -t tcp -p :18080 -C proxy.crt -K proxy.key -h aes-192-cfb -j pass`
-假设已经存在一个kcp的http(s)代理(密码是:demo123):127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5的普通代理,转换后的本地端口为18080。
+假设已经存在一个kcp的http(s)代理(密码是:demo123):127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5和ss的普通代理,转换后的本地端口为18080,ss加密方式:aes-192-cfb,ss密码:pass。
命令如下:
-`./proxy sps -S http -T kcp -P 127.0.0.1:8080 -t tcp -p :18080 --kcp-key demo123`
+`./proxy sps -S http -T kcp -P 127.0.0.1:8080 -t tcp -p :18080 --kcp-key demo123 -h aes-192-cfb -j pass`
-#### **6.3 SOCKS5转HTTP(S)+SOCKS5**
-假设已经存在一个普通的socks5代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5的普通代理,转换后的本地端口为18080。
+#### **6.3 SOCKS5转HTTP(S)+SOCKS5+SS**
+假设已经存在一个普通的socks5代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5和ss的普通代理,转换后的本地端口为18080,ss加密方式:aes-192-cfb,ss密码:pass。
命令如下:
-`./proxy sps -S socks -T tcp -P 127.0.0.1:8080 -t tcp -p :18080`
+`./proxy sps -S socks -T tcp -P 127.0.0.1:8080 -t tcp -p :18080 -h aes-192-cfb -j pass`
-假设已经存在一个tls的socks5代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5的普通代理,转换后的本地端口为18080,tls需要证书文件。
+假设已经存在一个tls的socks5代理:127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5和ss的普通代理,转换后的本地端口为18080,tls需要证书文件,ss加密方式:aes-192-cfb,ss密码:pass。
命令如下:
-`./proxy sps -S socks -T tls -P 127.0.0.1:8080 -t tcp -p :18080 -C proxy.crt -K proxy.key`
+`./proxy sps -S socks -T tls -P 127.0.0.1:8080 -t tcp -p :18080 -C proxy.crt -K proxy.key -h aes-192-cfb -j pass`
-假设已经存在一个kcp的socks5代理(密码是:demo123):127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5的普通代理,转换后的本地端口为18080。
+假设已经存在一个kcp的socks5代理(密码是:demo123):127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5和ss的普通代理,转换后的本地端口为18080,ss加密方式:aes-192-cfb,ss密码:pass。
命令如下:
-`./proxy sps -S socks -T kcp -P 127.0.0.1:8080 -t tcp -p :18080 --kcp-key demo123`
+`./proxy sps -S socks -T kcp -P 127.0.0.1:8080 -t tcp -p :18080 --kcp-key demo123 -h aes-192-cfb -j pass`
-#### **6.4 链式连接**
+#### **6.4 SS转HTTP(S)+SOCKS5+SS**
+SPS上级和本地支持ss协议,上级可以是SPS或者标准的ss服务.
+SPS本地默认提供HTTP(S)\SOCKS5\SPS三种代理,当上级是SOCKS5时转换后的SOCKS5和SS支持UDP功能.
+假设已经存在一个普通的SS或者SPS代理(开启了ss,加密方式:aes-256-cfb,密码:demo):127.0.0.1:8080,现在我们把它转为同时支持http(s)和socks5和ss的普通代理,转换后的本地端口为18080,转换后的ss加密方式:aes-192-cfb,ss密码:pass。
+命令如下:
+`./proxy sps -S ss -H aes-256-cfb -J pass -T tcp -P 127.0.0.1:8080 -t tcp -p :18080 -h aes-192-cfb -j pass`.
+
+#### **6.5 链式连接**

上面提过多个sps结点可以层级连接构建加密通道,假设有如下vps和家里的pc电脑。
vps01:2.2.2.2
@@ -944,11 +1083,11 @@ vps02:3.3.3.3
完成。
-#### **6.5 监听多个端口**
+#### **6.6 监听多个端口**
一般情况下监听一个端口就可以,不过如果作为反向代理需要同时监听80和443两个端口,那么-p参数是支持的,
格式是:`-p 0.0.0.0:80,0.0.0.0:443`,多个绑定用逗号分隔即可。
-#### **6.6 认证功能**
+#### **6.7 认证功能**
sps支持http(s)\socks5代理认证,可以级联认证,有四个重要的信息:
1:用户发送认证信息`user-auth`。
2:设置的本地认证信息`local-auth`。
@@ -991,7 +1130,7 @@ target:如果客户端是http(s)代理请求,这里代表的是请求的完整ur
如果没有-A参数,连接上级不使用认证.
-#### **6.7 自定义加密**
+#### **6.8 自定义加密**
proxy的sps代理在tcp之上可以通过tls标准加密以及kcp协议加密tcp数据,除此之外还支持在tls和kcp之后进行
自定义加密,也就是说自定义加密和tls|kcp是可以联合使用的,内部采用AES256加密,使用的时候只需要自己定义
一个密码即可,加密分为两个部分,一部分是本地(-z)是否加密解密,一部分是与上级(-Z)传输是否加密解密.
@@ -1019,7 +1158,7 @@ proxy的sps代理在tcp之上可以通过tls标准加密以及kcp协议加密tcp
`proxy sps -T tcp -P 3.3.3.3:8888 -Z other_password -t tcp -p :8080`
这样通过本地代理8080访问网站的时候就是通过与上级加密传输访问目标网站.
-#### **6.8 压缩传输**
+#### **6.9 压缩传输**
proxy的sps代理在tcp之上可以通过自定义加密和tls标准加密以及kcp协议加密tcp数据,在自定义加密之前还可以
对数据进行压缩,也就是说压缩功能和自定义加密和tls|kcp是可以联合使用的,压缩分为两个部分,
一部分是本地(-m)是否压缩传输,一部分是与上级(-M)传输是否压缩.
@@ -1045,8 +1184,7 @@ proxy的sps代理在tcp之上可以通过自定义加密和tls标准加密以及
`proxy sps -T tcp -P 3.3.3.3:8888 -M -t tcp -p :8080`
这样通过本地代理8080访问网站的时候就是通过与上级压缩传输访问目标网站.
-
-#### **6.9 禁用协议**
+#### **6.10 禁用协议**
SPS默认情况下一个端口支持http(s)和socks5两种代理协议,我们可以通过参数禁用某个协议
比如:
1.禁用HTTP(S)代理功能只保留SOCKS5代理功能,参数:`--disable-http`.
@@ -1055,7 +1193,31 @@ SPS默认情况下一个端口支持http(s)和socks5两种代理协议,我们可
1.禁用SOCKS5代理功能只保留HTTP(S)代理功能,参数:`--disable-socks`.
`proxy sps -T tcp -P 3.3.3.3:8888 -M -t tcp -p :8080 --disable-http`
-#### **6.10 查看帮助**
+#### **6.11 限速**
+
+假设存在SOCKS5上级:
+
+`proxy socks -p 2.2.2.2:33080 -z password -t tcp`
+
+sps下级,限速100K
+
+`proxy sps -S socks -P 2.2.2.2:33080 -T tcp -Z password -l 100K -t tcp -p :33080`
+
+通过`-l`参数即可指定,比如:100K 1.5M . 0意味着无限制.
+
+#### **6.12 指定出口IP**
+
+`--bind-listen`参数,就可以开启客户端用入口IP连接过来的,就用入口IP作为出口IP访问目标网站的功能。如果入口IP是内网IP,出口IP不会使用入口IP。
+
+`proxy sps -S socks -P 2.2.2.2:33080 -T tcp -Z password -l 100K -t tcp --bind-listen -p :33080`
+
+#### **6.13 证书参数使用base64数据**
+
+默认情况下-C,-K参数是crt证书和key文件的路径,
+
+如果是base64://开头,那么就认为后面的数据是base64编码的,会解码后使用.
+
+#### **6.14 查看帮助**
`./proxy help sps`
### **7.KCP配置**
diff --git a/VERSION b/VERSION
new file mode 100644
index 0000000..5049538
--- /dev/null
+++ b/VERSION
@@ -0,0 +1 @@
+6.0
\ No newline at end of file
diff --git a/config.go b/config.go
index 0e9da16..c97ebc7 100755
--- a/config.go
+++ b/config.go
@@ -10,23 +10,23 @@ import (
"os/exec"
"path"
"path/filepath"
+ "runtime/debug"
"runtime/pprof"
"time"
sdk "github.com/snail007/goproxy/sdk/android-ios"
- "github.com/snail007/goproxy/services"
- "github.com/snail007/goproxy/services/kcpcfg"
-
+ services "github.com/snail007/goproxy/services"
httpx "github.com/snail007/goproxy/services/http"
+ "github.com/snail007/goproxy/services/kcpcfg"
keygenx "github.com/snail007/goproxy/services/keygen"
mux "github.com/snail007/goproxy/services/mux"
socksx "github.com/snail007/goproxy/services/socks"
spsx "github.com/snail007/goproxy/services/sps"
tcpx "github.com/snail007/goproxy/services/tcp"
- tunnel "github.com/snail007/goproxy/services/tunnel"
+ tunnelx "github.com/snail007/goproxy/services/tunnel"
udpx "github.com/snail007/goproxy/services/udp"
-
kcp "github.com/xtaci/kcp-go"
+
"golang.org/x/crypto/pbkdf2"
kingpin "gopkg.in/alecthomas/kingpin.v2"
)
@@ -43,9 +43,9 @@ func initConfig() (err error) {
//define args
tcpArgs := tcpx.TCPArgs{}
httpArgs := httpx.HTTPArgs{}
- tunnelServerArgs := tunnel.TunnelServerArgs{}
- tunnelClientArgs := tunnel.TunnelClientArgs{}
- tunnelBridgeArgs := tunnel.TunnelBridgeArgs{}
+ tunnelServerArgs := tunnelx.TunnelServerArgs{}
+ tunnelClientArgs := tunnelx.TunnelClientArgs{}
+ tunnelBridgeArgs := tunnelx.TunnelBridgeArgs{}
muxServerArgs := mux.MuxServerArgs{}
muxClientArgs := mux.MuxClientArgs{}
muxBridgeArgs := mux.MuxBridgeArgs{}
@@ -55,11 +55,10 @@ func initConfig() (err error) {
dnsArgs := sdk.DNSArgs{}
keygenArgs := keygenx.KeygenArgs{}
kcpArgs := kcpcfg.KCPConfigArgs{}
-
//build srvice args
app = kingpin.New("proxy", "happy with proxy")
app.Author("snail").Version(APP_VERSION)
- debug := app.Flag("debug", "debug log output").Default("false").Bool()
+ isDebug := app.Flag("debug", "debug log output").Default("false").Bool()
daemon := app.Flag("daemon", "run proxy in background").Default("false").Bool()
forever := app.Flag("forever", "run proxy in forever,fail and retry").Default("false").Bool()
logfile := app.Flag("log", "log file path").Default("").String()
@@ -84,7 +83,7 @@ func initConfig() (err error) {
//########http#########
http := app.Command("http", "proxy on http mode")
- httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
+ httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').Strings()
httpArgs.CaCertFile = http.Flag("ca", "ca cert file for tls").Default("").String()
httpArgs.CertFile = http.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
httpArgs.KeyFile = http.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
@@ -115,7 +114,14 @@ func initConfig() (err error) {
httpArgs.ParentKey = http.Flag("parent-key", "the password for auto encrypt/decrypt parent connection data").Short('Z').Default("").String()
httpArgs.LocalCompress = http.Flag("local-compress", "auto compress/decompress data on local connection").Short('m').Default("false").Bool()
httpArgs.ParentCompress = http.Flag("parent-compress", "auto compress/decompress data on parent connection").Short('M').Default("false").Bool()
-
+ httpArgs.LoadBalanceMethod = http.Flag("lb-method", "load balance method when use multiple parent,can be ").Default("hash").Enum("roundrobin", "weight", "leastconn", "leasttime", "hash")
+ httpArgs.LoadBalanceTimeout = http.Flag("lb-timeout", "tcp milliseconds timeout of connecting to parent").Default("500").Int()
+ httpArgs.LoadBalanceRetryTime = http.Flag("lb-retrytime", "sleep time milliseconds after checking").Default("1000").Int()
+ httpArgs.LoadBalanceHashTarget = http.Flag("lb-hashtarget", "use target address to choose parent for LB").Default("false").Bool()
+ httpArgs.LoadBalanceOnlyHA = http.Flag("lb-onlyha", "use only `high availability mode` to choose parent for LB").Default("false").Bool()
+ httpArgs.RateLimit = http.Flag("rate-limit", "rate limit (bytes/second) of each connection, such as: 100K 1.5M . 0 means no limitation").Short('l').Default("0").String()
+ httpArgs.BindListen = http.Flag("bind-listen", "using listener binding IP when connect to target").Short('B').Default("false").Bool()
+ httpArgs.Debug = isDebug
//########tcp#########
tcp := app.Command("tcp", "proxy on tcp mode")
tcpArgs.Parent = tcp.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
@@ -124,6 +130,7 @@ func initConfig() (err error) {
tcpArgs.Timeout = tcp.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Short('e').Default("2000").Int()
tcpArgs.ParentType = tcp.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "udp", "kcp")
tcpArgs.LocalType = tcp.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")
+ tcpArgs.CheckParentInterval = tcp.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int()
tcpArgs.Local = tcp.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
tcpArgs.Jumper = tcp.Flag("jumper", "https or socks5 proxies used when connecting to parent, only worked of -T is tls or tcp, format is https://username:password@host:port https://host:port or socks5://username:password@host:port socks5://host:port").Short('J').Default("").String()
@@ -200,7 +207,7 @@ func initConfig() (err error) {
//########ssh#########
socks := app.Command("socks", "proxy on ssh mode")
- socksArgs.Parent = socks.Flag("parent", "parent ssh address, such as: \"23.32.32.19:22\"").Default("").Short('P').String()
+ socksArgs.Parent = socks.Flag("parent", "parent ssh address, such as: \"23.32.32.19:22\"").Default("").Short('P').Strings()
socksArgs.ParentType = socks.Flag("parent-type", "parent protocol type ").Default("tcp").Short('T').Enum("tls", "tcp", "kcp", "ssh")
socksArgs.LocalType = socks.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")
socksArgs.Local = socks.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
@@ -210,7 +217,7 @@ func initConfig() (err error) {
socksArgs.SSHUser = socks.Flag("ssh-user", "user for ssh").Short('u').Default("").String()
socksArgs.SSHKeyFile = socks.Flag("ssh-key", "private key file for ssh").Short('S').Default("").String()
socksArgs.SSHKeyFileSalt = socks.Flag("ssh-keysalt", "salt of ssh private key").Short('s').Default("").String()
- socksArgs.SSHPassword = socks.Flag("ssh-password", "password for ssh").Short('A').Default("").String()
+ socksArgs.SSHPassword = socks.Flag("ssh-password", "password for ssh").Short('D').Default("").String()
socksArgs.Always = socks.Flag("always", "always use parent proxy").Default("false").Bool()
socksArgs.Timeout = socks.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("5000").Int()
socksArgs.Interval = socks.Flag("interval", "check domain if blocked every interval seconds").Default("10").Int()
@@ -223,16 +230,25 @@ func initConfig() (err error) {
socksArgs.AuthURLTimeout = socks.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int()
socksArgs.AuthURLOkCode = socks.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int()
socksArgs.AuthURLRetry = socks.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("0").Int()
+ socksArgs.ParentAuth = socks.Flag("parent-auth", "parent socks auth username and password, such as: -A user1:pass1").Short('A').String()
socksArgs.DNSAddress = socks.Flag("dns-address", "if set this, proxy will use this dns for resolve doamin").Short('q').Default("").String()
socksArgs.DNSTTL = socks.Flag("dns-ttl", "caching seconds of dns query result").Short('e').Default("300").Int()
socksArgs.LocalKey = socks.Flag("local-key", "the password for auto encrypt/decrypt local connection data").Short('z').Default("").String()
socksArgs.ParentKey = socks.Flag("parent-key", "the password for auto encrypt/decrypt parent connection data").Short('Z').Default("").String()
socksArgs.LocalCompress = socks.Flag("local-compress", "auto compress/decompress data on local connection").Short('m').Default("false").Bool()
socksArgs.ParentCompress = socks.Flag("parent-compress", "auto compress/decompress data on parent connection").Short('M').Default("false").Bool()
+ socksArgs.LoadBalanceMethod = socks.Flag("lb-method", "load balance method when use multiple parent,can be ").Default("hash").Enum("roundrobin", "weight", "leastconn", "leasttime", "hash")
+ socksArgs.LoadBalanceTimeout = socks.Flag("lb-timeout", "tcp milliseconds timeout of connecting to parent").Default("500").Int()
+ socksArgs.LoadBalanceRetryTime = socks.Flag("lb-retrytime", "sleep time milliseconds after checking").Default("1000").Int()
+ socksArgs.LoadBalanceHashTarget = socks.Flag("lb-hashtarget", "use target address to choose parent for LB").Default("false").Bool()
+ socksArgs.LoadBalanceOnlyHA = socks.Flag("lb-onlyha", "use only `high availability mode` to choose parent for LB").Default("false").Bool()
+ socksArgs.RateLimit = socks.Flag("rate-limit", "rate limit (bytes/second) of each connection, such as: 100K 1.5M . 0 means no limitation").Short('l').Default("0").String()
+ socksArgs.BindListen = socks.Flag("bind-listen", "using listener binding IP when connect to target").Short('B').Default("false").Bool()
+ socksArgs.Debug = isDebug
//########socks+http(s)#########
sps := app.Command("sps", "proxy on socks+http(s) mode")
- spsArgs.Parent = sps.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
+ spsArgs.Parent = sps.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').Strings()
spsArgs.CertFile = sps.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
spsArgs.KeyFile = sps.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
spsArgs.CaCertFile = sps.Flag("ca", "ca cert file for tls").Default("").String()
@@ -240,7 +256,7 @@ func initConfig() (err error) {
spsArgs.ParentType = sps.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "kcp")
spsArgs.LocalType = sps.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")
spsArgs.Local = sps.Flag("local", "local ip:port to listen,multiple address use comma split,such as: 0.0.0.0:80,0.0.0.0:443").Short('p').Default(":33080").String()
- spsArgs.ParentServiceType = sps.Flag("parent-service-type", "parent service type ").Short('S').Enum("http", "socks")
+ spsArgs.ParentServiceType = sps.Flag("parent-service-type", "parent service type ").Short('S').Enum("http", "socks", "ss")
spsArgs.DNSAddress = sps.Flag("dns-address", "if set this, proxy will use this dns for resolve doamin").Short('q').Default("").String()
spsArgs.DNSTTL = sps.Flag("dns-ttl", "caching seconds of dns query result").Short('e').Default("300").Int()
spsArgs.AuthFile = sps.Flag("auth-file", "http basic auth file,\"username:password\" each line in file").Short('F').String()
@@ -255,8 +271,20 @@ func initConfig() (err error) {
spsArgs.ParentKey = sps.Flag("parent-key", "the password for auto encrypt/decrypt parent connection data").Short('Z').Default("").String()
spsArgs.LocalCompress = sps.Flag("local-compress", "auto compress/decompress data on local connection").Short('m').Default("false").Bool()
spsArgs.ParentCompress = sps.Flag("parent-compress", "auto compress/decompress data on parent connection").Short('M').Default("false").Bool()
+ spsArgs.SSMethod = sps.Flag("ss-method", "the following methods are supported: aes-128-cfb, aes-192-cfb, aes-256-cfb, bf-cfb, cast5-cfb, des-cfb, rc4-md5, rc4-md5-6, chacha20, salsa20, rc4, table, des-cfb, chacha20-ietf; if you use ss client , \"-t tcp\" is required").Short('h').Default("aes-256-cfb").String()
+ spsArgs.SSKey = sps.Flag("ss-key", "if you use ss client , \"-t tcp\" is required").Short('j').Default("sspassword").String()
+ spsArgs.ParentSSMethod = sps.Flag("parent-ss-method", "the following methods are supported: aes-128-cfb, aes-192-cfb, aes-256-cfb, bf-cfb, cast5-cfb, des-cfb, rc4-md5, rc4-md5-6, chacha20, salsa20, rc4, table, des-cfb, chacha20-ietf; if you use ss server as parent, \"-T tcp\" is required").Short('H').Default("aes-256-cfb").String()
+ spsArgs.ParentSSKey = sps.Flag("parent-ss-key", "if you use ss server as parent, \"-T tcp\" is required").Short('J').Default("sspassword").String()
spsArgs.DisableHTTP = sps.Flag("disable-http", "disable http(s) proxy").Default("false").Bool()
spsArgs.DisableSocks5 = sps.Flag("disable-socks", "disable socks proxy").Default("false").Bool()
+ spsArgs.DisableSS = sps.Flag("disable-ss", "disable ss proxy").Default("false").Bool()
+ spsArgs.LoadBalanceMethod = sps.Flag("lb-method", "load balance method when use multiple parent,can be ").Default("hash").Enum("roundrobin", "weight", "leastconn", "leasttime", "hash")
+ spsArgs.LoadBalanceTimeout = sps.Flag("lb-timeout", "tcp milliseconds timeout of connecting to parent").Default("500").Int()
+ spsArgs.LoadBalanceRetryTime = sps.Flag("lb-retrytime", "sleep time milliseconds after checking").Default("1000").Int()
+ spsArgs.LoadBalanceHashTarget = sps.Flag("lb-hashtarget", "use target address to choose parent for LB").Default("false").Bool()
+ spsArgs.LoadBalanceOnlyHA = sps.Flag("lb-onlyha", "use only `high availability mode` to choose parent for LB").Default("false").Bool()
+ spsArgs.RateLimit = sps.Flag("rate-limit", "rate limit (bytes/second) of each connection, such as: 100K 1.5M . 0 means no limitation").Short('l').Default("0").String()
+ spsArgs.Debug = isDebug
//########dns#########
dns := app.Command("dns", "proxy on dns server mode")
@@ -287,8 +315,6 @@ func initConfig() (err error) {
//parse args
serviceName := kingpin.MustParse(app.Parse(os.Args[1:]))
- isDebug = *debug
-
//set kcp config
switch *kcpArgs.Mode {
@@ -345,7 +371,7 @@ func initConfig() (err error) {
log := logger.New(os.Stderr, "", logger.Ldate|logger.Ltime)
flags := logger.Ldate
- if *debug {
+ if *isDebug {
flags |= logger.Lshortfile | logger.Lmicroseconds
cpuProfilingFile, _ = os.Create("cpu.prof")
memProfilingFile, _ = os.Create("memory.prof")
@@ -357,7 +383,6 @@ func initConfig() (err error) {
flags |= logger.Ltime
}
log.SetFlags(flags)
-
if *nolog {
log.SetOutput(ioutil.Discard)
} else if *logfile != "" {
@@ -391,6 +416,11 @@ func initConfig() (err error) {
}
}
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
for {
if cmd != nil {
cmd.Process.Kill()
@@ -410,11 +440,21 @@ func initConfig() (err error) {
scanner := bufio.NewScanner(cmdReader)
scannerStdErr := bufio.NewScanner(cmdReaderStderr)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
for scanner.Scan() {
fmt.Println(scanner.Text())
}
}()
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
for scannerStdErr.Scan() {
fmt.Println(scannerStdErr.Text())
}
@@ -436,7 +476,7 @@ func initConfig() (err error) {
}
if *logfile == "" {
poster()
- if *debug {
+ if *isDebug {
log.Println("[profiling] cpu profiling save to file : cpu.prof")
log.Println("[profiling] memory profiling save to file : memory.prof")
log.Println("[profiling] block profiling save to file : block.prof")
@@ -444,7 +484,7 @@ func initConfig() (err error) {
log.Println("[profiling] threadcreate profiling save to file : threadcreate.prof")
}
}
- //regist services and run service
+
//regist services and run service
switch serviceName {
case "http":
@@ -454,11 +494,11 @@ func initConfig() (err error) {
case "udp":
services.Regist(serviceName, udpx.NewUDP(), udpArgs, log)
case "tserver":
- services.Regist(serviceName, tunnel.NewTunnelServerManager(), tunnelServerArgs, log)
+ services.Regist(serviceName, tunnelx.NewTunnelServerManager(), tunnelServerArgs, log)
case "tclient":
- services.Regist(serviceName, tunnel.NewTunnelClient(), tunnelClientArgs, log)
+ services.Regist(serviceName, tunnelx.NewTunnelClient(), tunnelClientArgs, log)
case "tbridge":
- services.Regist(serviceName, tunnel.NewTunnelBridge(), tunnelBridgeArgs, log)
+ services.Regist(serviceName, tunnelx.NewTunnelBridge(), tunnelBridgeArgs, log)
case "server":
services.Regist(serviceName, mux.NewMuxServerManager(), muxServerArgs, log)
case "client":
@@ -474,7 +514,6 @@ func initConfig() (err error) {
case "keygen":
services.Regist(serviceName, keygenx.NewKeygen(), keygenArgs, log)
}
-
service, err = services.Run(serviceName, nil)
if err != nil {
log.Fatalf("run service [%s] fail, ERR:%s", serviceName, err)
@@ -483,16 +522,7 @@ func initConfig() (err error) {
}
func poster() {
- fmt.Printf(`
- ######## ######## ####### ## ## ## ##
- ## ## ## ## ## ## ## ## ## ##
- ## ## ## ## ## ## ## ## ####
- ######## ######## ## ## ### ##
- ## ## ## ## ## ## ## ##
- ## ## ## ## ## ## ## ##
- ## ## ## ####### ## ## ##
-
- v%s`+" by snail , blog : http://www.host900.com/\n\n", APP_VERSION)
+ fmt.Printf(`Proxy Enterprise Version v%s`+" by snail , blog : http://www.host900.com/\n\n", APP_VERSION)
}
func saveProfiling() {
goroutine := pprof.Lookup("goroutine")
diff --git a/core/cs/client/client.go b/core/cs/client/client.go
new file mode 100644
index 0000000..034ad21
--- /dev/null
+++ b/core/cs/client/client.go
@@ -0,0 +1,132 @@
+package client
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/snail007/goproxy/core/lib/kcpcfg"
+ compressconn "github.com/snail007/goproxy/core/lib/transport"
+ encryptconn "github.com/snail007/goproxy/core/lib/transport/encrypt"
+ "github.com/snail007/goproxy/core/dst"
+ kcp "github.com/xtaci/kcp-go"
+)
+
+func TlsConnectHost(host string, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
+ h := strings.Split(host, ":")
+ port, _ := strconv.Atoi(h[1])
+ return TlsConnect(h[0], port, timeout, certBytes, keyBytes, caCertBytes)
+}
+
+func TlsConnect(host string, port, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
+ conf, err := getRequestTlsConfig(certBytes, keyBytes, caCertBytes)
+ if err != nil {
+ return
+ }
+ _conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond)
+ if err != nil {
+ return
+ }
+ return *tls.Client(_conn, conf), err
+}
+func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) {
+
+ var cert tls.Certificate
+ cert, err = tls.X509KeyPair(certBytes, keyBytes)
+ if err != nil {
+ return
+ }
+ serverCertPool := x509.NewCertPool()
+ caBytes := certBytes
+ if caCertBytes != nil {
+ caBytes = caCertBytes
+
+ }
+ ok := serverCertPool.AppendCertsFromPEM(caBytes)
+ if !ok {
+ err = errors.New("failed to parse root certificate")
+ }
+ block, _ := pem.Decode(caBytes)
+ if block == nil {
+ panic("failed to parse certificate PEM")
+ }
+ x509Cert, _ := x509.ParseCertificate(block.Bytes)
+ if x509Cert == nil {
+ panic("failed to parse block")
+ }
+ conf = &tls.Config{
+ RootCAs: serverCertPool,
+ Certificates: []tls.Certificate{cert},
+ InsecureSkipVerify: true,
+ ServerName: x509Cert.Subject.CommonName,
+ VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+ opts := x509.VerifyOptions{
+ Roots: serverCertPool,
+ }
+ for _, rawCert := range rawCerts {
+ cert, _ := x509.ParseCertificate(rawCert)
+ _, err := cert.Verify(opts)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+ return
+}
+
+func TCPConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
+ conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
+ return
+}
+
+func TCPSConnectHost(hostAndPort string, method, password string, compress bool, timeout int) (conn net.Conn, err error) {
+ conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
+ if err != nil {
+ return
+ }
+ if compress {
+ conn = compressconn.NewCompConn(conn)
+ }
+ conn, err = encryptconn.NewConn(conn, method, password)
+ return
+}
+
+func TOUConnectHost(hostAndPort string, method, password string, compress bool, timeout int) (conn net.Conn, err error) {
+ udpConn, err := net.ListenUDP("udp", &net.UDPAddr{})
+ if err != nil {
+ panic(err)
+ }
+ // Create a DST mux around the packet connection with the default max
+ // packet size.
+ mux := dst.NewMux(udpConn, 0)
+ conn, err = mux.Dial("dst", hostAndPort)
+ if compress {
+ conn = compressconn.NewCompConn(conn)
+ }
+ conn, err = encryptconn.NewConn(conn, method, password)
+ return
+}
+func KCPConnectHost(hostAndPort string, config kcpcfg.KCPConfigArgs) (conn net.Conn, err error) {
+ kcpconn, err := kcp.DialWithOptions(hostAndPort, config.Block, *config.DataShard, *config.ParityShard)
+ if err != nil {
+ return
+ }
+ kcpconn.SetStreamMode(true)
+ kcpconn.SetWriteDelay(true)
+ kcpconn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion)
+ kcpconn.SetMtu(*config.MTU)
+ kcpconn.SetWindowSize(*config.SndWnd, *config.RcvWnd)
+ kcpconn.SetACKNoDelay(*config.AckNodelay)
+ if *config.NoComp {
+ return kcpconn, err
+ }
+ return compressconn.NewCompStream(kcpconn), err
+}
diff --git a/core/cs/server/server.go b/core/cs/server/server.go
new file mode 100644
index 0000000..3a46b9f
--- /dev/null
+++ b/core/cs/server/server.go
@@ -0,0 +1,342 @@
+package server
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "fmt"
+
+ logger "log"
+ "net"
+ "runtime/debug"
+ "strconv"
+
+ compressconn "github.com/snail007/goproxy/core/lib/transport"
+ transportc "github.com/snail007/goproxy/core/lib/transport"
+ encryptconn "github.com/snail007/goproxy/core/lib/transport/encrypt"
+ tou "github.com/snail007/goproxy/core/dst"
+
+ "github.com/snail007/goproxy/core/lib/kcpcfg"
+
+ kcp "github.com/xtaci/kcp-go"
+)
+
+func init() {
+
+}
+
+type ServerChannel struct {
+ ip string
+ port int
+ Listener *net.Listener
+ UDPListener *net.UDPConn
+ errAcceptHandler func(err error)
+ log *logger.Logger
+ TOUServer *tou.Mux
+}
+
+func NewServerChannel(ip string, port int, log *logger.Logger) ServerChannel {
+ return ServerChannel{
+ ip: ip,
+ port: port,
+ log: log,
+ errAcceptHandler: func(err error) {
+ log.Printf("accept error , ERR:%s", err)
+ },
+ }
+}
+func NewServerChannelHost(host string, log *logger.Logger) ServerChannel {
+ h, port, _ := net.SplitHostPort(host)
+ p, _ := strconv.Atoi(port)
+ return ServerChannel{
+ ip: h,
+ port: p,
+ log: log,
+ errAcceptHandler: func(err error) {
+ log.Printf("accept error , ERR:%s", err)
+ },
+ }
+}
+func (s *ServerChannel) SetErrAcceptHandler(fn func(err error)) {
+ s.errAcceptHandler = fn
+}
+func (s *ServerChannel) ListenSingleTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) {
+ return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, true)
+
+}
+func (s *ServerChannel) ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) {
+ return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, false)
+}
+func (s *ServerChannel) _ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn), single bool) (err error) {
+ s.Listener, err = s.listenTLS(s.ip, s.port, certBytes, keyBytes, caCertBytes, single)
+ if err == nil {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("ListenTLS crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ for {
+ var conn net.Conn
+ conn, err = (*s.Listener).Accept()
+ if err == nil {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("tls connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ fn(conn)
+ }()
+ } else {
+ s.errAcceptHandler(err)
+ (*s.Listener).Close()
+ break
+ }
+ }
+ }()
+ }
+ return
+}
+func (s *ServerChannel) listenTLS(ip string, port int, certBytes, keyBytes, caCertBytes []byte, single bool) (ln *net.Listener, err error) {
+ var cert tls.Certificate
+ cert, err = tls.X509KeyPair(certBytes, keyBytes)
+ if err != nil {
+ return
+ }
+ config := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ if !single {
+ clientCertPool := x509.NewCertPool()
+ caBytes := certBytes
+ if caCertBytes != nil {
+ caBytes = caCertBytes
+ }
+ ok := clientCertPool.AppendCertsFromPEM(caBytes)
+ if !ok {
+ err = errors.New("failed to parse root certificate")
+ }
+ config.ClientCAs = clientCertPool
+ config.ClientAuth = tls.RequireAndVerifyClientCert
+ }
+ _ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config)
+ if err == nil {
+ ln = &_ln
+ }
+ return
+}
+func (s *ServerChannel) ListenTCPS(method, password string, compress bool, fn func(conn net.Conn)) (err error) {
+ _, err = encryptconn.NewCipher(method, password)
+ if err != nil {
+ return
+ }
+ return s.ListenTCP(func(c net.Conn) {
+ if compress {
+ c = transportc.NewCompConn(c)
+ }
+ c, _ = encryptconn.NewConn(c, method, password)
+ fn(c)
+ })
+}
+func (s *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) {
+ var l net.Listener
+ l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.ip, s.port))
+ if err == nil {
+ s.Listener = &l
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ for {
+ var conn net.Conn
+ conn, err = (*s.Listener).Accept()
+ if err == nil {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ fn(conn)
+ }()
+ } else {
+ s.errAcceptHandler(err)
+ (*s.Listener).Close()
+ break
+ }
+ }
+ }()
+ }
+ return
+}
+func (s *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) {
+ addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port}
+ l, err := net.ListenUDP("udp", addr)
+ if err == nil {
+ s.UDPListener = l
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ for {
+ var buf = make([]byte, 2048)
+ n, srcAddr, err := (*s.UDPListener).ReadFromUDP(buf)
+ if err == nil {
+ packet := buf[0:n]
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ fn(packet, addr, srcAddr)
+ }()
+ } else {
+ s.errAcceptHandler(err)
+ (*s.UDPListener).Close()
+ break
+ }
+ }
+ }()
+ }
+ return
+}
+func (s *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net.Conn), log *logger.Logger) (err error) {
+ lis, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", s.ip, s.port), config.Block, *config.DataShard, *config.ParityShard)
+ if err == nil {
+ if err = lis.SetDSCP(*config.DSCP); err != nil {
+ log.Println("SetDSCP:", err)
+ return
+ }
+ if err = lis.SetReadBuffer(*config.SockBuf); err != nil {
+ log.Println("SetReadBuffer:", err)
+ return
+ }
+ if err = lis.SetWriteBuffer(*config.SockBuf); err != nil {
+ log.Println("SetWriteBuffer:", err)
+ return
+ }
+ s.Listener = new(net.Listener)
+ *s.Listener = lis
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("ListenKCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ for {
+ //var conn net.Conn
+ conn, err := lis.AcceptKCP()
+ if err == nil {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("kcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ conn.SetStreamMode(true)
+ conn.SetWriteDelay(true)
+ conn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion)
+ conn.SetMtu(*config.MTU)
+ conn.SetWindowSize(*config.SndWnd, *config.RcvWnd)
+ conn.SetACKNoDelay(*config.AckNodelay)
+ if *config.NoComp {
+ fn(conn)
+ } else {
+ cconn := transportc.NewCompStream(conn)
+ fn(cconn)
+ }
+ }()
+ } else {
+ s.errAcceptHandler(err)
+ (*s.Listener).Close()
+ break
+ }
+ }
+ }()
+ }
+ return
+}
+
+func (s *ServerChannel) ListenTOU(method, password string, compress bool, fn func(conn net.Conn)) (err error) {
+ addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port}
+ s.UDPListener, err = net.ListenUDP("udp", addr)
+ if err != nil {
+ s.log.Println(err)
+ return
+ }
+ s.TOUServer = tou.NewMux(s.UDPListener, 0)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("ListenRUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ for {
+ var conn net.Conn
+ conn, err = (*s.TOUServer).Accept()
+ if err == nil {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
+ }
+ }()
+ if compress {
+ conn = compressconn.NewCompConn(conn)
+ }
+ conn, err = encryptconn.NewConn(conn, method, password)
+ if err != nil {
+ conn.Close()
+ s.log.Println(err)
+ return
+ }
+ fn(conn)
+ }()
+ } else {
+ s.errAcceptHandler(err)
+ s.TOUServer.Close()
+ s.UDPListener.Close()
+ break
+ }
+ }
+ }()
+
+ return
+}
+func (s *ServerChannel) Close() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack()))
+ }
+ }()
+ if s.Listener != nil && *s.Listener != nil {
+ (*s.Listener).Close()
+ }
+ if s.TOUServer != nil {
+ s.TOUServer.Close()
+ }
+ if s.UDPListener != nil {
+ s.UDPListener.Close()
+ }
+}
+func (s *ServerChannel) Addr() string {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack()))
+ }
+ }()
+ if s.Listener != nil && *s.Listener != nil {
+ return (*s.Listener).Addr().String()
+ }
+
+ if s.UDPListener != nil {
+ return s.UDPListener.LocalAddr().String()
+ }
+ return ""
+}
diff --git a/core/cs/tests/transport_test.go b/core/cs/tests/transport_test.go
new file mode 100644
index 0000000..d530314
--- /dev/null
+++ b/core/cs/tests/transport_test.go
@@ -0,0 +1,49 @@
+package tests
+
+import (
+ "log"
+ "net"
+ "os"
+ "testing"
+
+ ctransport "github.com/snail007/goproxy/core/cs/client"
+ stransport "github.com/snail007/goproxy/core/cs/server"
+)
+
+func TestTCPS(t *testing.T) {
+ l := log.New(os.Stderr, "", log.LstdFlags)
+ s := stransport.NewServerChannelHost(":", l)
+ err := s.ListenTCPS("aes-256-cfb", "password", true, func(inconn net.Conn) {
+ buf := make([]byte, 2048)
+ _, err := inconn.Read(buf)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ _, err = inconn.Write([]byte("okay"))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ client, err := ctransport.TCPSConnectHost((*s.Listener).Addr().String(), "aes-256-cfb", "password", true, 1000)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+ _, err = client.Write([]byte("test"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ b := make([]byte, 20)
+ n, err := client.Read(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(b[:n]) != "okay" {
+ t.Fatalf("client revecive okay excepted,revecived : %s", string(b[:n]))
+ }
+}
diff --git a/core/dst/conn.go b/core/dst/conn.go
new file mode 100644
index 0000000..4543ec1
--- /dev/null
+++ b/core/dst/conn.go
@@ -0,0 +1,594 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "runtime/debug"
+
+ "math/rand"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+const (
+ defExpTime = 100 * time.Millisecond // N * (4 * RTT + RTTVar + SYN)
+ expCountClose = 8 // close connection after this many Exps
+ minTimeClose = 5 * time.Second // if at least this long has passed
+ maxInputBuffer = 8 << 20 // bytes
+ muxBufferPackets = 128 // buffer size of channel between mux and reader routine
+ rttMeasureWindow = 32 // number of packets to track for RTT averaging
+ rttMeasureSample = 128 // Sample every ... packet for RTT
+
+ // number of bytes to subtract from MTU when chunking data, to try to
+ // avoid fragmentation
+ sliceOverhead = 8 /*pppoe, similar*/ + 20 /*ipv4*/ + 8 /*udp*/ + 16 /*dst*/
+)
+
+func init() {
+ // Properly seed the random number generator that we use for sequence
+ // numbers and stuff.
+ buf := make([]byte, 8)
+ if n, err := crand.Read(buf); n != 8 || err != nil {
+ panic("init random failure")
+ }
+ rand.Seed(int64(binary.BigEndian.Uint64(buf)))
+}
+
+// TODO: export this interface when it's usable from the outside
+type congestionController interface {
+ Ack()
+ NegAck()
+ Exp()
+ SendWindow() int
+ PacketRate() int // PPS
+ UpdateRTT(time.Duration)
+}
+
+// Conn is an SDT connection carried over a Mux.
+type Conn struct {
+ // Set at creation, thereafter immutable:
+
+ mux *Mux
+ dst net.Addr
+ connID connectionID
+ remoteConnID connectionID
+ in chan packet
+ cc congestionController
+ packetSize int
+ closed chan struct{}
+ closeOnce sync.Once
+
+ // Touched by more than one goroutine, needs locking.
+
+ nextSeqNoMut sync.Mutex
+ nextSeqNo sequenceNo
+
+ inbufMut sync.Mutex
+ inbufCond *sync.Cond
+ inbuf bytes.Buffer
+
+ expMut sync.Mutex
+ exp *time.Timer
+
+ sendBuffer *sendBuffer // goroutine safe
+
+ packetDelays [rttMeasureWindow]time.Duration
+ packetDelaysSlot int
+ packetDelaysMut sync.Mutex
+
+ // Owned by the reader routine, needs no locking
+
+ recvBuffer packetList
+ nextRecvSeqNo sequenceNo
+ lastAckedSeqNo sequenceNo
+ lastNegAckedSeqNo sequenceNo
+ expCount int
+ expReset time.Time
+
+ // Only accessed atomically
+
+ packetsIn int64
+ packetsOut int64
+ bytesIn int64
+ bytesOut int64
+ resentPackets int64
+ droppedPackets int64
+ outOfOrderPackets int64
+
+ // Special
+
+ debugResetRecvSeqNo chan sequenceNo
+}
+
+func newConn(m *Mux, dst net.Addr) *Conn {
+ conn := &Conn{
+ mux: m,
+ dst: dst,
+ nextSeqNo: sequenceNo(rand.Uint32()),
+ packetSize: maxPacketSize,
+ in: make(chan packet, muxBufferPackets),
+ closed: make(chan struct{}),
+ sendBuffer: newSendBuffer(m),
+ exp: time.NewTimer(defExpTime),
+ debugResetRecvSeqNo: make(chan sequenceNo),
+ expReset: time.Now(),
+ }
+
+ conn.lastAckedSeqNo = conn.nextSeqNo - 1
+ conn.inbufCond = sync.NewCond(&conn.inbufMut)
+
+ conn.cc = newWindowCC()
+ conn.sendBuffer.SetWindowAndRate(conn.cc.SendWindow(), conn.cc.PacketRate())
+ conn.recvBuffer.Resize(128)
+
+ return conn
+}
+
+func (c *Conn) start() {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ c.reader()
+ }()
+}
+
+func (c *Conn) reader() {
+ if debugConnection {
+ log.Println(c, "reader() starting")
+ defer log.Println(c, "reader() exiting")
+ }
+
+ for {
+ select {
+ case <-c.closed:
+ // Ack any received but not yet acked messages.
+ c.sendAck(0)
+
+ // Send a shutdown message.
+ c.nextSeqNoMut.Lock()
+ c.mux.write(packet{
+ src: c.connID,
+ dst: c.dst,
+ hdr: header{
+ packetType: typeShutdown,
+ connID: c.remoteConnID,
+ sequenceNo: c.nextSeqNo,
+ },
+ })
+ c.nextSeqNo++
+ c.nextSeqNoMut.Unlock()
+ atomic.AddInt64(&c.packetsOut, 1)
+ atomic.AddInt64(&c.bytesOut, dstHeaderLen)
+ return
+
+ case pkt := <-c.in:
+ atomic.AddInt64(&c.packetsIn, 1)
+ atomic.AddInt64(&c.bytesIn, dstHeaderLen+int64(len(pkt.data)))
+
+ c.expCount = 1
+
+ switch pkt.hdr.packetType {
+ case typeData:
+ c.rcvData(pkt)
+ case typeAck:
+ c.rcvAck(pkt)
+ case typeNegAck:
+ c.rcvNegAck(pkt)
+ case typeShutdown:
+ c.rcvShutdown(pkt)
+ default:
+ log.Println("Unhandled packet", pkt)
+ continue
+ }
+
+ case <-c.exp.C:
+ c.eventExp()
+ c.resetExp()
+
+ case n := <-c.debugResetRecvSeqNo:
+ // Back door for testing
+ c.lastAckedSeqNo = n - 1
+ c.nextRecvSeqNo = n
+ }
+ }
+}
+
+func (c *Conn) eventExp() {
+ c.expCount++
+
+ if c.sendBuffer.lost.Len() > 0 || c.sendBuffer.send.Len() > 0 {
+ c.cc.Exp()
+ c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate())
+ c.sendBuffer.ScheduleResend()
+
+ if debugConnection {
+ log.Println(c, "did resends due to Exp")
+ }
+
+ if c.expCount > expCountClose && time.Since(c.expReset) > minTimeClose {
+ if debugConnection {
+ log.Println(c, "close due to Exp")
+ }
+
+ // We're shutting down due to repeated exp:s. Don't wait for the
+ // send buffer to drain, which it would otherwise do in
+ // c.Close()..
+ c.sendBuffer.CrashStop()
+
+ c.Close()
+ }
+ }
+}
+
+func (c *Conn) rcvAck(pkt packet) {
+ ack := pkt.hdr.sequenceNo
+
+ if debugConnection {
+ log.Printf("%v read Ack %v", c, ack)
+ }
+
+ c.cc.Ack()
+
+ if ack%rttMeasureSample == 0 {
+ if ts := timestamp(binary.BigEndian.Uint32(pkt.data)); ts > 0 {
+ if delay := time.Duration(timestampMicros()-ts) * time.Microsecond; delay > 0 {
+ c.packetDelaysMut.Lock()
+ c.packetDelays[c.packetDelaysSlot] = delay
+ c.packetDelaysSlot = (c.packetDelaysSlot + 1) % len(c.packetDelays)
+ c.packetDelaysMut.Unlock()
+
+ if rtt, n := c.averageDelay(); n > 8 {
+ c.cc.UpdateRTT(rtt)
+ }
+ }
+ }
+ }
+
+ c.sendBuffer.Acknowledge(ack)
+ c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate())
+
+ c.resetExp()
+}
+
+func (c *Conn) averageDelay() (time.Duration, int) {
+ var total time.Duration
+ var n int
+
+ c.packetDelaysMut.Lock()
+ for _, d := range c.packetDelays {
+ if d != 0 {
+ total += d
+ n++
+ }
+ }
+ c.packetDelaysMut.Unlock()
+
+ if n == 0 {
+ return 0, 0
+ }
+ return total / time.Duration(n), n
+}
+
+func (c *Conn) rcvNegAck(pkt packet) {
+ nak := pkt.hdr.sequenceNo
+
+ if debugConnection {
+ log.Printf("%v read NegAck %v", c, nak)
+ }
+
+ c.sendBuffer.NegativeAck(nak)
+
+ //c.cc.NegAck()
+ c.resetExp()
+}
+
+func (c *Conn) rcvShutdown(pkt packet) {
+ // XXX: We accept shutdown packets somewhat from the future since the
+ // sender will number the shutdown after any packets that might still be
+ // in the write buffer. This should be fixed to let the write buffer empty
+ // on close and reduce the window here.
+ if pkt.LessSeq(c.nextRecvSeqNo + 128) {
+ if debugConnection {
+ log.Println(c, "close due to shutdown")
+ }
+ c.Close()
+ }
+}
+
+func (c *Conn) rcvData(pkt packet) {
+ if debugConnection {
+ log.Println(c, "recv data", pkt.hdr)
+ }
+
+ if pkt.LessSeq(c.nextRecvSeqNo) {
+ if debugConnection {
+ log.Printf("%v old packet received; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo)
+ }
+ atomic.AddInt64(&c.droppedPackets, 1)
+ return
+ }
+
+ if debugConnection {
+ log.Println(c, "into recv buffer:", pkt)
+ }
+ c.recvBuffer.InsertSorted(pkt)
+ if c.recvBuffer.LowestSeq() == c.nextRecvSeqNo {
+ for _, pkt := range c.recvBuffer.PopSequence(^sequenceNo(0)) {
+ if debugConnection {
+ log.Println(c, "from recv buffer:", pkt)
+ }
+
+ // An in-sequence packet.
+
+ c.nextRecvSeqNo = pkt.hdr.sequenceNo + 1
+
+ c.sendAck(pkt.hdr.timestamp)
+
+ c.inbufMut.Lock()
+ for c.inbuf.Len() > len(pkt.data)+maxInputBuffer {
+ c.inbufCond.Wait()
+ select {
+ case <-c.closed:
+ return
+ default:
+ }
+ }
+
+ c.inbuf.Write(pkt.data)
+ c.inbufCond.Broadcast()
+ c.inbufMut.Unlock()
+ }
+ } else {
+ if debugConnection {
+ log.Printf("%v lost; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo)
+ }
+ c.recvBuffer.InsertSorted(pkt)
+ c.sendNegAck()
+ atomic.AddInt64(&c.outOfOrderPackets, 1)
+ }
+}
+
+func (c *Conn) sendAck(ts timestamp) {
+ if c.lastAckedSeqNo == c.nextRecvSeqNo {
+ return
+ }
+
+ var buf [4]byte
+ binary.BigEndian.PutUint32(buf[:], uint32(ts))
+ c.mux.write(packet{
+ src: c.connID,
+ dst: c.dst,
+ hdr: header{
+ packetType: typeAck,
+ connID: c.remoteConnID,
+ sequenceNo: c.nextRecvSeqNo,
+ },
+ data: buf[:],
+ })
+
+ atomic.AddInt64(&c.packetsOut, 1)
+ atomic.AddInt64(&c.bytesOut, dstHeaderLen)
+ if debugConnection {
+ log.Printf("%v send Ack %v", c, c.nextRecvSeqNo)
+ }
+
+ c.lastAckedSeqNo = c.nextRecvSeqNo
+}
+
+func (c *Conn) sendNegAck() {
+ if c.lastNegAckedSeqNo == c.nextRecvSeqNo {
+ return
+ }
+
+ c.mux.write(packet{
+ src: c.connID,
+ dst: c.dst,
+ hdr: header{
+ packetType: typeNegAck,
+ connID: c.remoteConnID,
+ sequenceNo: c.nextRecvSeqNo,
+ },
+ })
+
+ atomic.AddInt64(&c.packetsOut, 1)
+ atomic.AddInt64(&c.bytesOut, dstHeaderLen)
+ if debugConnection {
+ log.Printf("%v send NegAck %v", c, c.nextRecvSeqNo)
+ }
+
+ c.lastNegAckedSeqNo = c.nextRecvSeqNo
+}
+
+func (c *Conn) resetExp() {
+ d, _ := c.averageDelay()
+ d = d*4 + 10*time.Millisecond
+
+ if d < defExpTime {
+ d = defExpTime
+ }
+
+ c.expMut.Lock()
+ c.exp.Reset(d)
+ c.expMut.Unlock()
+}
+
+// String returns a string representation of the connection.
+func (c *Conn) String() string {
+ return fmt.Sprintf("%v/%v/%v", c.connID, c.LocalAddr(), c.RemoteAddr())
+}
+
+// Read reads data from the connection.
+// Read can be made to time out and return a Error with Timeout() == true
+// after a fixed time limit; see SetDeadline and SetReadDeadline.
+func (c *Conn) Read(b []byte) (n int, err error) {
+ c.inbufMut.Lock()
+ defer c.inbufMut.Unlock()
+ for c.inbuf.Len() == 0 {
+ select {
+ case <-c.closed:
+ return 0, io.EOF
+ default:
+ }
+ c.inbufCond.Wait()
+ }
+ return c.inbuf.Read(b)
+}
+
+// Write writes data to the connection.
+// Write can be made to time out and return a Error with Timeout() == true
+// after a fixed time limit; see SetDeadline and SetWriteDeadline.
+func (c *Conn) Write(b []byte) (n int, err error) {
+ select {
+ case <-c.closed:
+ return 0, ErrClosedConn
+ default:
+ }
+
+ sent := 0
+ sliceSize := c.packetSize - sliceOverhead
+ for i := 0; i < len(b); i += sliceSize {
+ nxt := i + sliceSize
+ if nxt > len(b) {
+ nxt = len(b)
+ }
+ slice := b[i:nxt]
+ sliceCopy := c.mux.buffers.Get().([]byte)[:len(slice)]
+ copy(sliceCopy, slice)
+
+ c.nextSeqNoMut.Lock()
+ pkt := packet{
+ src: c.connID,
+ dst: c.dst,
+ hdr: header{
+ packetType: typeData,
+ sequenceNo: c.nextSeqNo,
+ connID: c.remoteConnID,
+ },
+ data: sliceCopy,
+ }
+ c.nextSeqNo++
+ c.nextSeqNoMut.Unlock()
+
+ if err := c.sendBuffer.Write(pkt); err != nil {
+ return sent, err
+ }
+
+ atomic.AddInt64(&c.packetsOut, 1)
+ atomic.AddInt64(&c.bytesOut, int64(len(slice)+dstHeaderLen))
+
+ sent += len(slice)
+ c.resetExp()
+ }
+ return sent, nil
+}
+
+// Close closes the connection.
+// Any blocked Read or Write operations will be unblocked and return errors.
+func (c *Conn) Close() error {
+ c.closeOnce.Do(func() {
+ if debugConnection {
+ log.Println(c, "explicit close start")
+ defer log.Println(c, "explicit close done")
+ }
+
+ // XXX: Ugly hack to implement lingering sockets...
+ time.Sleep(4 * defExpTime)
+
+ c.sendBuffer.Stop()
+ c.mux.removeConn(c)
+ close(c.closed)
+
+ c.inbufMut.Lock()
+ c.inbufCond.Broadcast()
+ c.inbufMut.Unlock()
+ })
+ return nil
+}
+
+// LocalAddr returns the local network address.
+func (c *Conn) LocalAddr() net.Addr {
+ return c.mux.Addr()
+}
+
+// RemoteAddr returns the remote network address.
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.dst
+}
+
+// SetDeadline sets the read and write deadlines associated
+// with the connection. It is equivalent to calling both
+// SetReadDeadline and SetWriteDeadline.
+//
+// A deadline is an absolute time after which I/O operations
+// fail with a timeout (see type Error) instead of
+// blocking. The deadline applies to all future I/O, not just
+// the immediately following call to Read or Write.
+//
+// An idle timeout can be implemented by repeatedly extending
+// the deadline after successful Read or Write calls.
+//
+// A zero value for t means I/O operations will not time out.
+//
+// BUG(jb): SetDeadline is not implemented.
+func (c *Conn) SetDeadline(t time.Time) error {
+ return ErrNotImplemented
+}
+
+// SetReadDeadline sets the deadline for future Read calls.
+// A zero value for t means Read will not time out.
+//
+// BUG(jb): SetReadDeadline is not implemented.
+func (c *Conn) SetReadDeadline(t time.Time) error {
+ return ErrNotImplemented
+}
+
+// SetWriteDeadline sets the deadline for future Write calls.
+// Even if write times out, it may return n > 0, indicating that
+// some of the data was successfully written.
+// A zero value for t means Write will not time out.
+//
+// BUG(jb): SetWriteDeadline is not implemented.
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+ return ErrNotImplemented
+}
+
+type Statistics struct {
+ DataPacketsIn int64
+ DataPacketsOut int64
+ DataBytesIn int64
+ DataBytesOut int64
+ ResentPackets int64
+ DroppedPackets int64
+ OutOfOrderPackets int64
+}
+
+// String returns a printable represetnation of the Statistics.
+func (s Statistics) String() string {
+ return fmt.Sprintf("PktsIn: %d, PktsOut: %d, BytesIn: %d, BytesOut: %d, PktsResent: %d, PktsDropped: %d, PktsOutOfOrder: %d",
+ s.DataPacketsIn, s.DataPacketsOut, s.DataBytesIn, s.DataBytesOut, s.ResentPackets, s.DroppedPackets, s.OutOfOrderPackets)
+}
+
+// GetStatistics returns a snapsht of the current connection statistics.
+func (c *Conn) GetStatistics() Statistics {
+ return Statistics{
+ DataPacketsIn: atomic.LoadInt64(&c.packetsIn),
+ DataPacketsOut: atomic.LoadInt64(&c.packetsOut),
+ DataBytesIn: atomic.LoadInt64(&c.bytesIn),
+ DataBytesOut: atomic.LoadInt64(&c.bytesOut),
+ ResentPackets: atomic.LoadInt64(&c.resentPackets),
+ DroppedPackets: atomic.LoadInt64(&c.droppedPackets),
+ OutOfOrderPackets: atomic.LoadInt64(&c.outOfOrderPackets),
+ }
+}
diff --git a/core/dst/cookie.go b/core/dst/cookie.go
new file mode 100644
index 0000000..0daeb8b
--- /dev/null
+++ b/core/dst/cookie.go
@@ -0,0 +1,29 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/binary"
+ "net"
+)
+
+var cookieKey = make([]byte, 16)
+
+func init() {
+ _, err := rand.Reader.Read(cookieKey)
+ if err != nil {
+ panic(err)
+ }
+}
+
+func cookie(remote net.Addr) uint32 {
+ hash := sha256.New()
+ hash.Write([]byte(remote.String()))
+ hash.Write(cookieKey)
+ bs := hash.Sum(nil)
+ return binary.BigEndian.Uint32(bs)
+}
diff --git a/core/dst/debug.go b/core/dst/debug.go
new file mode 100644
index 0000000..ed9f3fe
--- /dev/null
+++ b/core/dst/debug.go
@@ -0,0 +1,26 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ "os"
+ "strings"
+)
+
+var (
+ debugConnection bool
+ debugMux bool
+ debugCC bool
+)
+
+func init() {
+ debug := make(map[string]bool)
+ for _, s := range strings.Split(os.Getenv("DSTDEBUG"), ",") {
+ debug[strings.TrimSpace(s)] = true
+ }
+ debugConnection = debug["conn"]
+ debugMux = debug["mux"]
+ debugCC = debug["cc"]
+}
diff --git a/core/dst/doc.go b/core/dst/doc.go
new file mode 100644
index 0000000..fe4c768
--- /dev/null
+++ b/core/dst/doc.go
@@ -0,0 +1,12 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+/*
+
+Package dst implements the Datagram Stream Transfer protocol.
+
+DST is a way to get reliable stream connections (like TCP) on top of UDP.
+
+*/
+package dst
diff --git a/core/dst/errors.go b/core/dst/errors.go
new file mode 100644
index 0000000..dd4eb0e
--- /dev/null
+++ b/core/dst/errors.go
@@ -0,0 +1,23 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+// Error represents the various dst-internal error conditions.
+type Error struct {
+ Err string
+}
+
+// Error returns a string representation of the error.
+func (e Error) Error() string {
+ return e.Err
+}
+
+var (
+ ErrClosedConn = &Error{"operation on closed connection"}
+ ErrClosedMux = &Error{"operation on closed mux"}
+ ErrHandshakeTimeout = &Error{"handshake timeout"}
+ ErrNotDST = &Error{"network is not dst"}
+ ErrNotImplemented = &Error{"not implemented"}
+)
diff --git a/core/dst/mux.go b/core/dst/mux.go
new file mode 100644
index 0000000..967e73c
--- /dev/null
+++ b/core/dst/mux.go
@@ -0,0 +1,429 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ "fmt"
+ "runtime/debug"
+
+ "net"
+ "sync"
+ "time"
+)
+
+const (
+ maxIncomingRequests = 1024
+ maxPacketSize = 500
+ handshakeTimeout = 5 * time.Second
+ handshakeInterval = 1 * time.Second
+)
+
+// Mux is a UDP multiplexer of DST connections.
+type Mux struct {
+ conn net.PacketConn
+ packetSize int
+
+ conns map[connectionID]*Conn
+ handshakes map[connectionID]chan packet
+ connsMut sync.Mutex
+
+ incoming chan *Conn
+ closed chan struct{}
+ closeOnce sync.Once
+
+ buffers *sync.Pool
+}
+
+// NewMux creates a new DST Mux on top of a packet connection.
+func NewMux(conn net.PacketConn, packetSize int) *Mux {
+ if packetSize <= 0 {
+ packetSize = maxPacketSize
+ }
+ m := &Mux{
+ conn: conn,
+ packetSize: packetSize,
+ conns: map[connectionID]*Conn{},
+ handshakes: make(map[connectionID]chan packet),
+ incoming: make(chan *Conn, maxIncomingRequests),
+ closed: make(chan struct{}),
+ buffers: &sync.Pool{
+ New: func() interface{} {
+ return make([]byte, packetSize)
+ },
+ },
+ }
+
+ // Attempt to maximize buffer space. Start at 16 MB and work downwards 0.5
+ // MB at a time.
+
+ if conn, ok := conn.(*net.UDPConn); ok {
+ for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 {
+ err := conn.SetReadBuffer(buf)
+ if err == nil {
+ if debugMux {
+ log.Println(m, "read buffer is", buf)
+ }
+ break
+ }
+ }
+ for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 {
+ err := conn.SetWriteBuffer(buf)
+ if err == nil {
+ if debugMux {
+ log.Println(m, "write buffer is", buf)
+ }
+ break
+ }
+ }
+ }
+
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ m.readerLoop()
+ }()
+ return m
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (m *Mux) Accept() (net.Conn, error) {
+ return m.AcceptDST()
+}
+
+// AcceptDST waits for and returns the next connection to the listener.
+func (m *Mux) AcceptDST() (*Conn, error) {
+ conn, ok := <-m.incoming
+ if !ok {
+ return nil, ErrClosedMux
+ }
+ return conn, nil
+}
+
+// Close closes the listener.
+// Any blocked Accept operations will be unblocked and return errors.
+func (m *Mux) Close() error {
+ var err error = ErrClosedMux
+ m.closeOnce.Do(func() {
+ err = m.conn.Close()
+ close(m.incoming)
+ close(m.closed)
+ })
+ return err
+}
+
+// Addr returns the listener's network address.
+func (m *Mux) Addr() net.Addr {
+ return m.conn.LocalAddr()
+}
+
+// Dial connects to the address on the named network.
+//
+// Network must be "dst".
+//
+// Addresses have the form host:port. If host is a literal IPv6 address or
+// host name, it must be enclosed in square brackets as in "[::1]:80",
+// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
+// SplitHostPort manipulate addresses in this form.
+//
+// Examples:
+// Dial("dst", "12.34.56.78:80")
+// Dial("dst", "google.com:http")
+// Dial("dst", "[2001:db8::1]:http")
+// Dial("dst", "[fe80::1%lo0]:80")
+func (m *Mux) Dial(network, addr string) (net.Conn, error) {
+ return m.DialDST(network, addr)
+}
+
+// Dial connects to the address on the named network.
+//
+// Network must be "dst".
+//
+// Addresses have the form host:port. If host is a literal IPv6 address or
+// host name, it must be enclosed in square brackets as in "[::1]:80",
+// "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
+// SplitHostPort manipulate addresses in this form.
+//
+// Examples:
+// Dial("dst", "12.34.56.78:80")
+// Dial("dst", "google.com:http")
+// Dial("dst", "[2001:db8::1]:http")
+// Dial("dst", "[fe80::1%lo0]:80")
+func (m *Mux) DialDST(network, addr string) (*Conn, error) {
+ if network != "dst" {
+ return nil, ErrNotDST
+ }
+
+ dst, err := net.ResolveUDPAddr("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+
+ resp := make(chan packet)
+
+ m.connsMut.Lock()
+ connID := m.newConnID()
+ m.handshakes[connID] = resp
+ m.connsMut.Unlock()
+
+ conn, err := m.clientHandshake(dst, connID, resp)
+
+ m.connsMut.Lock()
+ defer m.connsMut.Unlock()
+ delete(m.handshakes, connID)
+
+ if err != nil {
+ return nil, err
+ }
+
+ m.conns[connID] = conn
+ return conn, nil
+}
+
+// handshake performs the client side handshake (i.e. Dial)
+func (m *Mux) clientHandshake(dst net.Addr, connID connectionID, resp chan packet) (*Conn, error) {
+ if debugMux {
+ log.Printf("%v dial %v connID %v", m, dst, connID)
+ }
+
+ nextHandshake := time.NewTimer(0)
+ defer nextHandshake.Stop()
+
+ handshakeTimeout := time.NewTimer(handshakeTimeout)
+ defer handshakeTimeout.Stop()
+
+ var remoteCookie uint32
+ seqNo := randomSeqNo()
+
+ for {
+ select {
+ case <-m.closed:
+ // Failure. The mux has been closed.
+ return nil, ErrClosedConn
+
+ case <-handshakeTimeout.C:
+ // Handshake timeout. Close and abort.
+ return nil, ErrHandshakeTimeout
+
+ case <-nextHandshake.C:
+ // Send a handshake request.
+
+ m.write(packet{
+ src: connID,
+ dst: dst,
+ hdr: header{
+ packetType: typeHandshake,
+ flags: flagRequest,
+ connID: 0,
+ sequenceNo: seqNo,
+ timestamp: timestampMicros(),
+ },
+ data: handshakeData{uint32(m.packetSize), connID, remoteCookie}.marshal(),
+ })
+ nextHandshake.Reset(handshakeInterval)
+
+ case pkt := <-resp:
+ hd := unmarshalHandshakeData(pkt.data)
+
+ if pkt.hdr.flags&flagCookie == flagCookie {
+ // We should resend the handshake request with a different cookie value.
+ remoteCookie = hd.cookie
+ nextHandshake.Reset(0)
+ } else if pkt.hdr.flags&flagResponse == flagResponse {
+ // Successfull handshake response.
+ conn := newConn(m, dst)
+
+ conn.connID = connID
+ conn.remoteConnID = hd.connID
+ conn.nextRecvSeqNo = pkt.hdr.sequenceNo + 1
+ conn.packetSize = int(hd.packetSize)
+ if conn.packetSize > m.packetSize {
+ conn.packetSize = m.packetSize
+ }
+
+ conn.nextSeqNo = seqNo + 1
+
+ conn.start()
+
+ return conn, nil
+ }
+ }
+ }
+}
+
+func (m *Mux) readerLoop() {
+ buf := make([]byte, m.packetSize)
+ for {
+ buf = buf[:cap(buf)]
+ n, from, err := m.conn.ReadFrom(buf)
+ if err != nil {
+ m.Close()
+ return
+ }
+ buf = buf[:n]
+
+ hdr := unmarshalHeader(buf)
+
+ var bufCopy []byte
+ if len(buf) > dstHeaderLen {
+ bufCopy = m.buffers.Get().([]byte)[:len(buf)-dstHeaderLen]
+ copy(bufCopy, buf[dstHeaderLen:])
+ }
+
+ pkt := packet{hdr: hdr, data: bufCopy}
+ if debugMux {
+ log.Println(m, "read", pkt)
+ }
+
+ if hdr.packetType == typeHandshake {
+ m.incomingHandshake(from, hdr, bufCopy)
+ } else {
+ m.connsMut.Lock()
+ conn, ok := m.conns[hdr.connID]
+ m.connsMut.Unlock()
+
+ if ok {
+ conn.in <- packet{
+ dst: nil,
+ hdr: hdr,
+ data: bufCopy,
+ }
+ } else if debugMux && hdr.packetType != typeShutdown {
+ log.Printf("packet %v for unknown conn %v", hdr, hdr.connID)
+ }
+ }
+ }
+}
+
+func (m *Mux) incomingHandshake(from net.Addr, hdr header, data []byte) {
+ if hdr.connID == 0 {
+ // A new incoming handshake request.
+ m.incomingHandshakeRequest(from, hdr, data)
+ } else {
+ // A response to an ongoing handshake.
+ m.incomingHandshakeResponse(from, hdr, data)
+ }
+}
+
+func (m *Mux) incomingHandshakeRequest(from net.Addr, hdr header, data []byte) {
+ if hdr.flags&flagRequest != flagRequest {
+ log.Printf("Handshake pattern with flags 0x%x to connID zero", hdr.flags)
+ return
+ }
+
+ hd := unmarshalHandshakeData(data)
+
+ correctCookie := cookie(from)
+ if hd.cookie != correctCookie {
+ // Incorrect or missing SYN cookie. Send back a handshake
+ // with the expected one.
+ m.write(packet{
+ dst: from,
+ hdr: header{
+ packetType: typeHandshake,
+ flags: flagResponse | flagCookie,
+ connID: hd.connID,
+ timestamp: timestampMicros(),
+ },
+ data: handshakeData{
+ packetSize: uint32(m.packetSize),
+ cookie: correctCookie,
+ }.marshal(),
+ })
+ return
+ }
+
+ seqNo := randomSeqNo()
+
+ m.connsMut.Lock()
+ connID := m.newConnID()
+
+ conn := newConn(m, from)
+ conn.connID = connID
+ conn.remoteConnID = hd.connID
+ conn.nextSeqNo = seqNo + 1
+ conn.nextRecvSeqNo = hdr.sequenceNo + 1
+ conn.packetSize = int(hd.packetSize)
+ if conn.packetSize > m.packetSize {
+ conn.packetSize = m.packetSize
+ }
+ conn.start()
+
+ m.conns[connID] = conn
+ m.connsMut.Unlock()
+
+ m.write(packet{
+ dst: from,
+ hdr: header{
+ packetType: typeHandshake,
+ flags: flagResponse,
+ connID: hd.connID,
+ sequenceNo: seqNo,
+ timestamp: timestampMicros(),
+ },
+ data: handshakeData{
+ connID: conn.connID,
+ packetSize: uint32(conn.packetSize),
+ }.marshal(),
+ })
+
+ m.incoming <- conn
+}
+
+func (m *Mux) incomingHandshakeResponse(from net.Addr, hdr header, data []byte) {
+ m.connsMut.Lock()
+ handShake, ok := m.handshakes[hdr.connID]
+ m.connsMut.Unlock()
+
+ if ok {
+ // This is a response to a handshake in progress.
+ handShake <- packet{
+ dst: nil,
+ hdr: hdr,
+ data: data,
+ }
+ } else if debugMux && hdr.packetType != typeShutdown {
+ log.Printf("Handshake packet %v for unknown conn %v", hdr, hdr.connID)
+ }
+}
+
+func (m *Mux) write(pkt packet) (int, error) {
+ buf := m.buffers.Get().([]byte)
+ buf = buf[:dstHeaderLen+len(pkt.data)]
+ pkt.hdr.marshal(buf)
+ copy(buf[dstHeaderLen:], pkt.data)
+ if debugMux {
+ log.Println(m, "write", pkt)
+ }
+ n, err := m.conn.WriteTo(buf, pkt.dst)
+ m.buffers.Put(buf)
+ return n, err
+}
+
+func (m *Mux) String() string {
+ return fmt.Sprintf("Mux-%v", m.Addr())
+}
+
+// Find a unique connection ID
+func (m *Mux) newConnID() connectionID {
+ for {
+ connID := randomConnID()
+ if _, ok := m.conns[connID]; ok {
+ continue
+ }
+ if _, ok := m.handshakes[connID]; ok {
+ continue
+ }
+ return connID
+ }
+}
+
+func (m *Mux) removeConn(c *Conn) {
+ m.connsMut.Lock()
+ delete(m.conns, c.connID)
+ m.connsMut.Unlock()
+}
diff --git a/core/dst/packetlist.go b/core/dst/packetlist.go
new file mode 100644
index 0000000..6e5ce9a
--- /dev/null
+++ b/core/dst/packetlist.go
@@ -0,0 +1,119 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+type packetList struct {
+ packets []packet
+ slot int
+}
+
+// CutLessSeq cuts packets from the start of the list with sequence numbers
+// lower than seq. Returns the number of packets that were cut.
+func (l *packetList) CutLessSeq(seq sequenceNo) int {
+ var i, cut int
+ for i = range l.packets {
+ if i == l.slot {
+ break
+ }
+ if !l.packets[i].LessSeq(seq) {
+ break
+ }
+ cut++
+ }
+ if cut > 0 {
+ l.Cut(cut)
+ }
+ return cut
+}
+
+func (l *packetList) Cut(n int) {
+ copy(l.packets, l.packets[n:])
+ l.slot -= n
+}
+
+func (l *packetList) Full() bool {
+ return l.slot == len(l.packets)
+}
+
+func (l *packetList) All() []packet {
+ return l.packets[:l.slot]
+}
+
+func (l *packetList) Append(pkt packet) bool {
+ if l.slot == len(l.packets) {
+ return false
+ }
+ l.packets[l.slot] = pkt
+ l.slot++
+ return true
+}
+
+func (l *packetList) AppendAll(pkts []packet) {
+ l.packets = append(l.packets[:l.slot], pkts...)
+ l.slot += len(pkts)
+}
+
+func (l *packetList) Cap() int {
+ return len(l.packets)
+}
+
+func (l *packetList) Len() int {
+ return l.slot
+}
+
+func (l *packetList) Resize(s int) {
+ if s <= cap(l.packets) {
+ l.packets = l.packets[:s]
+ } else {
+ t := make([]packet, s)
+ copy(t, l.packets)
+ l.packets = t
+ }
+}
+
+func (l *packetList) InsertSorted(pkt packet) {
+ for i := range l.packets {
+ if i >= l.slot {
+ l.packets[i] = pkt
+ l.slot++
+ return
+ }
+ if pkt.hdr.sequenceNo == l.packets[i].hdr.sequenceNo {
+ return
+ }
+ if pkt.Less(l.packets[i]) {
+ copy(l.packets[i+1:], l.packets[i:])
+ l.packets[i] = pkt
+ if l.slot < len(l.packets) {
+ l.slot++
+ }
+ return
+ }
+ }
+}
+
+func (l *packetList) LowestSeq() sequenceNo {
+ return l.packets[0].hdr.sequenceNo
+}
+
+func (l *packetList) PopSequence(maxSeq sequenceNo) []packet {
+ highSeq := l.packets[0].hdr.sequenceNo
+ if highSeq >= maxSeq {
+ return nil
+ }
+
+ var i int
+ for i = 1; i < l.slot; i++ {
+ seq := l.packets[i].hdr.sequenceNo
+ if seq != highSeq+1 || seq >= maxSeq {
+ break
+ }
+ highSeq++
+ }
+ pkts := make([]packet, i)
+ copy(pkts, l.packets[:i])
+ l.Cut(i)
+ return pkts
+}
diff --git a/core/dst/packets.go b/core/dst/packets.go
new file mode 100644
index 0000000..c44b8cd
--- /dev/null
+++ b/core/dst/packets.go
@@ -0,0 +1,155 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ "encoding/binary"
+ "fmt"
+ "net"
+)
+
+const dstHeaderLen = 12
+
+type packetType int8
+
+const (
+ typeHandshake packetType = 0x0
+ typeData = 0x1
+ typeAck = 0x2
+ typeNegAck = 0x3
+ typeShutdown = 0x4
+)
+
+func (t packetType) String() string {
+ switch t {
+ case typeData:
+ return "data"
+ case typeHandshake:
+ return "handshake"
+ case typeAck:
+ return "ack"
+ case typeNegAck:
+ return "negAck"
+ case typeShutdown:
+ return "shutdown"
+ default:
+ return "unknown"
+ }
+}
+
+type connectionID uint32
+
+func (c connectionID) String() string {
+ return fmt.Sprintf("Ci%08x", uint32(c))
+}
+
+type sequenceNo uint32
+
+func (s sequenceNo) String() string {
+ return fmt.Sprintf("Sq%d", uint32(s))
+}
+
+type timestamp uint32
+
+func (t timestamp) String() string {
+ return fmt.Sprintf("Ts%d", uint32(t))
+}
+
+const (
+ flagRequest = 1 << 0 // This packet is a handshake request
+ flagResponse = 1 << 1 // This packet is a handshake response
+ flagCookie = 1 << 2 // This packet contains a coookie challenge
+)
+
+type header struct {
+ packetType packetType // 4 bits
+ flags uint8 // 4 bits
+ connID connectionID // 24 bits
+ sequenceNo sequenceNo
+ timestamp timestamp
+}
+
+func (h header) marshal(bs []byte) {
+ binary.BigEndian.PutUint32(bs, uint32(h.connID&0xffffff))
+ bs[0] = h.flags | uint8(h.packetType)<<4
+ binary.BigEndian.PutUint32(bs[4:], uint32(h.sequenceNo))
+ binary.BigEndian.PutUint32(bs[8:], uint32(h.timestamp))
+}
+
+func unmarshalHeader(bs []byte) header {
+ var h header
+ h.packetType = packetType(bs[0] >> 4)
+ h.flags = bs[0] & 0xf
+ h.connID = connectionID(binary.BigEndian.Uint32(bs) & 0xffffff)
+ h.sequenceNo = sequenceNo(binary.BigEndian.Uint32(bs[4:]))
+ h.timestamp = timestamp(binary.BigEndian.Uint32(bs[8:]))
+ return h
+}
+
+func (h header) String() string {
+ return fmt.Sprintf("header{type=%s flags=0x%x connID=%v seq=%v time=%v}", h.packetType, h.flags, h.connID, h.sequenceNo, h.timestamp)
+}
+
+type handshakeData struct {
+ packetSize uint32
+ connID connectionID
+ cookie uint32
+}
+
+func (h handshakeData) marshalInto(data []byte) {
+ binary.BigEndian.PutUint32(data[0:], h.packetSize)
+ binary.BigEndian.PutUint32(data[4:], uint32(h.connID))
+ binary.BigEndian.PutUint32(data[8:], h.cookie)
+}
+
+func (h handshakeData) marshal() []byte {
+ var data [12]byte
+ h.marshalInto(data[:])
+ return data[:]
+}
+
+func unmarshalHandshakeData(data []byte) handshakeData {
+ var h handshakeData
+ h.packetSize = binary.BigEndian.Uint32(data[0:])
+ h.connID = connectionID(binary.BigEndian.Uint32(data[4:]))
+ h.cookie = binary.BigEndian.Uint32(data[8:])
+ return h
+}
+
+func (h handshakeData) String() string {
+ return fmt.Sprintf("handshake{size=%d connID=%v cookie=0x%08x}", h.packetSize, h.connID, h.cookie)
+}
+
+type packet struct {
+ src connectionID
+ dst net.Addr
+ hdr header
+ data []byte
+}
+
+func (p packet) String() string {
+ var dst string
+ if p.dst != nil {
+ dst = "dst=" + p.dst.String() + " "
+ }
+ switch p.hdr.packetType {
+ case typeHandshake:
+ return fmt.Sprintf("%spacket{src=%v %v %v}", dst, p.src, p.hdr, unmarshalHandshakeData(p.data))
+ default:
+ return fmt.Sprintf("%spacket{src=%v %v data[:%d]}", dst, p.src, p.hdr, len(p.data))
+ }
+}
+
+func (p packet) LessSeq(seq sequenceNo) bool {
+ diff := seq - p.hdr.sequenceNo
+ if diff == 0 {
+ return false
+ }
+ return diff < 1<<31
+}
+
+func (a packet) Less(b packet) bool {
+ return a.LessSeq(b.hdr.sequenceNo)
+}
diff --git a/core/dst/sendbuffer.go b/core/dst/sendbuffer.go
new file mode 100644
index 0000000..5e53f17
--- /dev/null
+++ b/core/dst/sendbuffer.go
@@ -0,0 +1,268 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ "fmt"
+ "runtime/debug"
+
+ "sync"
+
+ "github.com/juju/ratelimit"
+)
+
+/*
+ sendWindow
+ v
+ [S|S|S|S|Q|Q|Q|Q| | | | | | | | | ]
+ ^ ^writeSlot
+ sendSlot
+*/
+type sendBuffer struct {
+ mux *Mux // we send packets here
+ scheduler *ratelimit.Bucket // sets send rate for packets
+
+ sendWindow int // maximum number of outstanding non-acked packets
+ packetRate int // target pps
+
+ send packetList // buffered packets
+ sendSlot int // buffer slot from which to send next packet
+
+ lost packetList // list of packets reported lost by timeout
+ lostSlot int // next lost packet to resend
+
+ closed bool
+ closing bool
+ mut sync.Mutex
+ cond *sync.Cond
+}
+
+const (
+ schedulerRate = 1e6
+ schedulerCapacity = schedulerRate / 40
+)
+
+// newSendBuffer creates a new send buffer with a zero window.
+// SetRateAndWindow() must be called to set an initial packet rate and send
+// window before using.
+func newSendBuffer(m *Mux) *sendBuffer {
+ b := &sendBuffer{
+ mux: m,
+ scheduler: ratelimit.NewBucketWithRate(schedulerRate, schedulerCapacity),
+ }
+ b.cond = sync.NewCond(&b.mut)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ b.writerLoop()
+ }()
+ return b
+}
+
+// Write puts a new packet in send buffer and schedules a send. Blocks when
+// the window size is or would be exceeded.
+func (b *sendBuffer) Write(pkt packet) error {
+ b.mut.Lock()
+ defer b.mut.Unlock()
+
+ for b.send.Full() || b.send.Len() >= b.sendWindow {
+ if b.closing {
+ return ErrClosedConn
+ }
+ if debugConnection {
+ log.Println(b, "Write blocked")
+ }
+ b.cond.Wait()
+ }
+ if !b.send.Append(pkt) {
+ panic("bug: append failed")
+ }
+ b.cond.Broadcast()
+ return nil
+}
+
+// Acknowledge removes packets with lower sequence numbers from the loss list
+// or send buffer.
+func (b *sendBuffer) Acknowledge(seq sequenceNo) {
+ b.mut.Lock()
+
+ if cut := b.lost.CutLessSeq(seq); cut > 0 {
+ if debugConnection {
+ log.Println(b, "cut", cut, "from loss list")
+ }
+ // Next resend should always start with the first packet, regardless
+ // of what we might already have resent previously.
+ b.lostSlot = 0
+ b.cond.Broadcast()
+ }
+
+ if cut := b.send.CutLessSeq(seq); cut > 0 {
+ if debugConnection {
+ log.Println(b, "cut", cut, "from send list")
+ }
+ b.sendSlot -= cut
+ b.cond.Broadcast()
+ }
+
+ b.mut.Unlock()
+}
+
+func (b *sendBuffer) NegativeAck(seq sequenceNo) {
+ b.mut.Lock()
+
+ pkts := b.send.PopSequence(seq)
+ if cut := len(pkts); cut > 0 {
+ b.lost.AppendAll(pkts)
+ if debugConnection {
+ log.Println(b, "cut", cut, "from send list, adding to loss list")
+ log.Println(seq, pkts)
+ }
+ b.sendSlot -= cut
+ b.lostSlot = 0
+ b.cond.Broadcast()
+ }
+
+ b.mut.Unlock()
+}
+
+// ScheduleResend arranges for a resend of all currently unacknowledged
+// packets.
+func (b *sendBuffer) ScheduleResend() {
+ b.mut.Lock()
+
+ if b.sendSlot > 0 {
+ // There are packets that have been sent but not acked. Move them from
+ // the send buffer to the loss list for retransmission.
+ if debugConnection {
+ log.Println(b, "scheduled resend from send list", b.sendSlot)
+ }
+
+ // Append the packets to the loss list and rewind the send buffer
+ b.lost.AppendAll(b.send.All()[:b.sendSlot])
+ b.send.Cut(b.sendSlot)
+ b.sendSlot = 0
+ b.cond.Broadcast()
+ }
+
+ if b.lostSlot > 0 {
+ // Also resend whatever was already in the loss list
+ if debugConnection {
+ log.Println(b, "scheduled resend from loss list", b.lostSlot)
+ }
+ b.lostSlot = 0
+ b.cond.Broadcast()
+ }
+
+ b.mut.Unlock()
+}
+
+// SetWindowAndRate sets the window size (in packets) and packet rate (in
+// packets per second) to use when sending.
+func (b *sendBuffer) SetWindowAndRate(sendWindow, packetRate int) {
+ b.mut.Lock()
+ if debugConnection {
+ log.Println(b, "new window & rate", sendWindow, packetRate)
+ }
+ b.packetRate = packetRate
+ b.sendWindow = sendWindow
+ if b.sendWindow > b.send.Cap() {
+ b.send.Resize(b.sendWindow)
+ b.cond.Broadcast()
+ }
+ b.mut.Unlock()
+}
+
+// Stop stops the send buffer from any doing further sending, but waits for
+// the current buffers to be drained.
+func (b *sendBuffer) Stop() {
+ b.mut.Lock()
+
+ if b.closed || b.closing {
+ return
+ }
+
+ b.closing = true
+ for b.lost.Len() > 0 || b.send.Len() > 0 {
+ b.cond.Wait()
+ }
+
+ b.closed = true
+ b.cond.Broadcast()
+ b.mut.Unlock()
+}
+
+// CrashStop stops the send buffer from any doing further sending, without
+// waiting for buffers to drain.
+func (b *sendBuffer) CrashStop() {
+ b.mut.Lock()
+
+ if b.closed || b.closing {
+ return
+ }
+
+ b.closing = true
+ b.closed = true
+ b.cond.Broadcast()
+ b.mut.Unlock()
+}
+
+func (b *sendBuffer) String() string {
+ return fmt.Sprintf("sendBuffer@%p", b)
+}
+
+func (b *sendBuffer) writerLoop() {
+ if debugConnection {
+ log.Println(b, "writer() starting")
+ defer log.Println(b, "writer() exiting")
+ }
+
+ b.scheduler.Take(schedulerCapacity)
+ for {
+ var pkt packet
+ b.mut.Lock()
+ for b.lostSlot >= b.sendWindow ||
+ (b.sendSlot == b.send.Len() && b.lostSlot == b.lost.Len()) {
+ if b.closed {
+ b.mut.Unlock()
+ return
+ }
+
+ if debugConnection {
+ log.Println(b, "writer() paused", b.lostSlot, b.sendSlot, b.sendWindow, b.lost.Len())
+ }
+ b.cond.Wait()
+ }
+
+ if b.lostSlot < b.lost.Len() {
+ pkt = b.lost.All()[b.lostSlot]
+ pkt.hdr.timestamp = timestampMicros()
+ b.lostSlot++
+
+ if debugConnection {
+ log.Println(b, "resend", b.lostSlot, b.lost.Len(), b.sendWindow, pkt.hdr.connID, pkt.hdr.sequenceNo)
+ }
+ } else if b.sendSlot < b.send.Len() {
+ pkt = b.send.All()[b.sendSlot]
+ pkt.hdr.timestamp = timestampMicros()
+ b.sendSlot++
+
+ if debugConnection {
+ log.Println(b, "send", b.sendSlot, b.send.Len(), b.sendWindow, pkt.hdr.connID, pkt.hdr.sequenceNo)
+ }
+ }
+
+ b.cond.Broadcast()
+ packetRate := b.packetRate
+ b.mut.Unlock()
+
+ if pkt.dst != nil {
+ b.scheduler.Wait(schedulerRate / int64(packetRate))
+ b.mux.write(pkt)
+ }
+ }
+}
diff --git a/core/dst/util.go b/core/dst/util.go
new file mode 100644
index 0000000..cdd6d91
--- /dev/null
+++ b/core/dst/util.go
@@ -0,0 +1,29 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ logger "log"
+ "math/rand"
+ "os"
+ "time"
+)
+
+var log = logger.New(os.Stderr, "", logger.LstdFlags)
+
+func SetLogger(l *logger.Logger) {
+ log = l
+}
+func timestampMicros() timestamp {
+ return timestamp(time.Now().UnixNano() / 1000)
+}
+
+func randomSeqNo() sequenceNo {
+ return sequenceNo(rand.Uint32())
+}
+
+func randomConnID() connectionID {
+ return connectionID(rand.Uint32() & 0xffffff)
+}
diff --git a/core/dst/windowcc.go b/core/dst/windowcc.go
new file mode 100644
index 0000000..cc2dd69
--- /dev/null
+++ b/core/dst/windowcc.go
@@ -0,0 +1,144 @@
+// Copyright 2014 The DST Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package dst
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "time"
+)
+
+type windowCC struct {
+ minWindow int
+ maxWindow int
+ currentWindow int
+ minRate int
+ maxRate int
+ currentRate int
+ targetRate int
+
+ curRTT time.Duration
+ minRTT time.Duration
+
+ statsFile io.WriteCloser
+ start time.Time
+}
+
+func newWindowCC() *windowCC {
+ var statsFile io.WriteCloser
+
+ if debugCC {
+ statsFile, _ = os.Create(fmt.Sprintf("cc-log-%d.csv", time.Now().Unix()))
+ fmt.Fprintf(statsFile, "ms,minWin,maxWin,curWin,minRate,maxRate,curRate,minRTT,curRTT\n")
+ }
+
+ return &windowCC{
+ minWindow: 1, // Packets
+ maxWindow: 16 << 10,
+ currentWindow: 1,
+
+ minRate: 100, // PPS
+ maxRate: 80e3, // Roughly 1 Gbps at 1500 bytes per packet
+ currentRate: 100,
+ targetRate: 1000,
+
+ minRTT: 10 * time.Second,
+ statsFile: statsFile,
+ start: time.Now(),
+ }
+}
+
+func (w *windowCC) Ack() {
+ if w.curRTT > w.minRTT+100*time.Millisecond {
+ return
+ }
+
+ changed := false
+
+ if w.currentWindow < w.maxWindow {
+ w.currentWindow++
+ changed = true
+ }
+
+ if w.currentRate != w.targetRate {
+ w.currentRate = (w.currentRate*7 + w.targetRate) / 8
+ changed = true
+ }
+
+ if changed && debugCC {
+ w.log()
+ log.Println("Ack", w.currentWindow, w.currentRate)
+ }
+}
+
+func (w *windowCC) NegAck() {
+ if w.currentWindow > w.minWindow {
+ w.currentWindow /= 2
+ }
+ if w.currentRate > w.minRate {
+ w.currentRate /= 2
+ }
+ if debugCC {
+ w.log()
+ log.Println("NegAck", w.currentWindow, w.currentRate)
+ }
+}
+
+func (w *windowCC) Exp() {
+ w.currentWindow = w.minWindow
+ if debugCC {
+ w.log()
+ log.Println("Exp", w.currentWindow, w.currentRate)
+ }
+}
+
+func (w *windowCC) SendWindow() int {
+ if w.currentWindow < w.minWindow {
+ return w.minWindow
+ }
+ if w.currentWindow > w.maxWindow {
+ return w.maxWindow
+ }
+ return w.currentWindow
+}
+
+func (w *windowCC) PacketRate() int {
+ if w.currentRate < w.minRate {
+ return w.minRate
+ }
+ if w.currentRate > w.maxRate {
+ return w.maxRate
+ }
+ return w.currentRate
+}
+
+func (w *windowCC) UpdateRTT(rtt time.Duration) {
+ w.curRTT = rtt
+ if w.curRTT < w.minRTT {
+ w.minRTT = w.curRTT
+ if debugCC {
+ log.Println("Min RTT", w.minRTT)
+ }
+ }
+
+ if w.curRTT > w.minRTT+200*time.Millisecond && w.targetRate > 2*w.minRate {
+ w.targetRate -= w.minRate
+ } else if w.curRTT < w.minRTT+20*time.Millisecond && w.targetRate < w.maxRate {
+ w.targetRate += w.minRate
+ }
+
+ if debugCC {
+ w.log()
+ log.Println("RTT", w.curRTT, "target rate", w.targetRate, "current rate", w.currentRate, "current window", w.currentWindow)
+ }
+}
+
+func (w *windowCC) log() {
+ if w.statsFile == nil {
+ return
+ }
+ fmt.Fprintf(w.statsFile, "%.02f,%d,%d,%d,%d,%d,%d,%.02f,%.02f\n", time.Since(w.start).Seconds()*1000, w.minWindow, w.maxWindow, w.currentWindow, w.minRate, w.maxRate, w.currentRate, w.minRTT.Seconds()*1000, w.curRTT.Seconds()*1000)
+}
diff --git a/core/lib/buf/leakybuf.go b/core/lib/buf/leakybuf.go
new file mode 100644
index 0000000..ee56728
--- /dev/null
+++ b/core/lib/buf/leakybuf.go
@@ -0,0 +1,52 @@
+// Provides leaky buffer, based on the example in Effective Go.
+package buf
+
+type LeakyBuf struct {
+ bufSize int // size of each buffer
+ freeList chan []byte
+}
+
+const LeakyBufSize = 2048 // data.len(2) + hmacsha1(10) + data(4096)
+const maxNBuf = 2048
+
+var LeakyBuffer = NewLeakyBuf(maxNBuf, LeakyBufSize)
+
+func Get() (b []byte) {
+ return LeakyBuffer.Get()
+}
+func Put(b []byte) {
+ LeakyBuffer.Put(b)
+}
+
+// NewLeakyBuf creates a leaky buffer which can hold at most n buffer, each
+// with bufSize bytes.
+func NewLeakyBuf(n, bufSize int) *LeakyBuf {
+ return &LeakyBuf{
+ bufSize: bufSize,
+ freeList: make(chan []byte, n),
+ }
+}
+
+// Get returns a buffer from the leaky buffer or create a new buffer.
+func (lb *LeakyBuf) Get() (b []byte) {
+ select {
+ case b = <-lb.freeList:
+ default:
+ b = make([]byte, lb.bufSize)
+ }
+ return
+}
+
+// Put add the buffer into the free buffer pool for reuse. Panic if the buffer
+// size is not the same with the leaky buffer's. This is intended to expose
+// error usage of leaky buffer.
+func (lb *LeakyBuf) Put(b []byte) {
+ if len(b) != lb.bufSize {
+ panic("invalid buffer size that's put into leaky buffer")
+ }
+ select {
+ case lb.freeList <- b:
+ default:
+ }
+ return
+}
diff --git a/core/lib/ioutils/utils.go b/core/lib/ioutils/utils.go
new file mode 100644
index 0000000..70560cb
--- /dev/null
+++ b/core/lib/ioutils/utils.go
@@ -0,0 +1,68 @@
+package ioutils
+
+import (
+ "io"
+ logger "log"
+
+ lbuf "github.com/snail007/goproxy/core/lib/buf"
+)
+
+func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) {
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ log.Printf("bind crashed %s", err)
+ }
+ }()
+ e1 := make(chan interface{}, 1)
+ e2 := make(chan interface{}, 1)
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ log.Printf("bind crashed %s", err)
+ }
+ }()
+ //_, err := io.Copy(dst, src)
+ err := ioCopy(dst, src)
+ e1 <- err
+ }()
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ log.Printf("bind crashed %s", err)
+ }
+ }()
+ //_, err := io.Copy(src, dst)
+ err := ioCopy(src, dst)
+ e2 <- err
+ }()
+ var err interface{}
+ select {
+ case err = <-e1:
+ //log.Printf("e1")
+ case err = <-e2:
+ //log.Printf("e2")
+ }
+ src.Close()
+ dst.Close()
+ if fn != nil {
+ fn(err)
+ }
+ }()
+}
+func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) {
+ buf := lbuf.LeakyBuffer.Get()
+ defer lbuf.LeakyBuffer.Put(buf)
+ n := 0
+ for {
+ n, err = src.Read(buf)
+ if n > 0 {
+ if _, e := dst.Write(buf[0:n]); e != nil {
+ return e
+ }
+ }
+ if err != nil {
+ return
+ }
+ }
+}
diff --git a/core/lib/kcpcfg/args.go b/core/lib/kcpcfg/args.go
new file mode 100644
index 0000000..5d3b67c
--- /dev/null
+++ b/core/lib/kcpcfg/args.go
@@ -0,0 +1,24 @@
+package kcpcfg
+
+import kcp "github.com/xtaci/kcp-go"
+
+type KCPConfigArgs struct {
+ Key *string
+ Crypt *string
+ Mode *string
+ MTU *int
+ SndWnd *int
+ RcvWnd *int
+ DataShard *int
+ ParityShard *int
+ DSCP *int
+ NoComp *bool
+ AckNodelay *bool
+ NoDelay *int
+ Interval *int
+ Resend *int
+ NoCongestion *int
+ SockBuf *int
+ KeepAlive *int
+ Block kcp.BlockCrypt
+}
diff --git a/utils/map.go b/core/lib/mapx/map.go
similarity index 87%
rename from utils/map.go
rename to core/lib/mapx/map.go
index 8ec82cf..651a818 100644
--- a/utils/map.go
+++ b/core/lib/mapx/map.go
@@ -1,7 +1,9 @@
-package utils
+package mapx
import (
"encoding/json"
+ "fmt"
+ "runtime/debug"
"sync"
)
@@ -150,7 +152,14 @@ type Tuple struct {
func (m ConcurrentMap) Iter() <-chan Tuple {
chans := snapshot(m)
ch := make(chan Tuple)
- go fanIn(chans, ch)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ fanIn(chans, ch)
+ }()
return ch
}
@@ -162,7 +171,14 @@ func (m ConcurrentMap) IterBuffered() <-chan Tuple {
total += cap(c)
}
ch := make(chan Tuple, total)
- go fanIn(chans, ch)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ fanIn(chans, ch)
+ }()
return ch
}
@@ -177,6 +193,11 @@ func snapshot(m ConcurrentMap) (chans []chan Tuple) {
// Foreach shard.
for index, shard := range m {
go func(index int, shard *ConcurrentMapShared) {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
// Foreach key, value pair.
shard.RLock()
chans[index] = make(chan Tuple, len(shard.items))
@@ -197,12 +218,19 @@ func fanIn(chans []chan Tuple, out chan Tuple) {
wg := sync.WaitGroup{}
wg.Add(len(chans))
for _, ch := range chans {
- go func(ch chan Tuple) {
- for t := range ch {
- out <- t
- }
- wg.Done()
- }(ch)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ func(ch chan Tuple) {
+ for t := range ch {
+ out <- t
+ }
+ wg.Done()
+ }(ch)
+ }()
}
wg.Wait()
close(out)
@@ -244,19 +272,31 @@ func (m ConcurrentMap) Keys() []string {
count := m.Count()
ch := make(chan string, count)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
// Foreach shard.
wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT)
for _, shard := range m {
- go func(shard *ConcurrentMapShared) {
- // Foreach key, value pair.
- shard.RLock()
- for key := range shard.items {
- ch <- key
- }
- shard.RUnlock()
- wg.Done()
- }(shard)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ func(shard *ConcurrentMapShared) {
+ // Foreach key, value pair.
+ shard.RLock()
+ for key := range shard.items {
+ ch <- key
+ }
+ shard.RUnlock()
+ wg.Done()
+ }(shard)
+ }()
}
wg.Wait()
close(ch)
diff --git a/core/lib/socks5/socks5.go b/core/lib/socks5/socks5.go
new file mode 100644
index 0000000..e67b5ff
--- /dev/null
+++ b/core/lib/socks5/socks5.go
@@ -0,0 +1,159 @@
+package socks5
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net"
+ "strconv"
+)
+
+const (
+ Method_NO_AUTH = uint8(0x00)
+ Method_GSSAPI = uint8(0x01)
+ Method_USER_PASS = uint8(0x02)
+ Method_IANA = uint8(0x7F)
+ Method_RESVERVE = uint8(0x80)
+ Method_NONE_ACCEPTABLE = uint8(0xFF)
+ VERSION_V5 = uint8(0x05)
+ CMD_CONNECT = uint8(0x01)
+ CMD_BIND = uint8(0x02)
+ CMD_ASSOCIATE = uint8(0x03)
+ ATYP_IPV4 = uint8(0x01)
+ ATYP_DOMAIN = uint8(0x03)
+ ATYP_IPV6 = uint8(0x04)
+ REP_SUCCESS = uint8(0x00)
+ REP_REQ_FAIL = uint8(0x01)
+ REP_RULE_FORBIDDEN = uint8(0x02)
+ REP_NETWOR_UNREACHABLE = uint8(0x03)
+ REP_HOST_UNREACHABLE = uint8(0x04)
+ REP_CONNECTION_REFUSED = uint8(0x05)
+ REP_TTL_TIMEOUT = uint8(0x06)
+ REP_CMD_UNSUPPORTED = uint8(0x07)
+ REP_ATYP_UNSUPPORTED = uint8(0x08)
+ REP_UNKNOWN = uint8(0x09)
+ RSV = uint8(0x00)
+)
+
+var (
+ ZERO_IP = []byte{0x00, 0x00, 0x00, 0x00}
+ ZERO_PORT = []byte{0x00, 0x00}
+)
+var Socks5Errors = []string{
+ "",
+ "general failure",
+ "connection forbidden",
+ "network unreachable",
+ "host unreachable",
+ "connection refused",
+ "TTL expired",
+ "command not supported",
+ "address type not supported",
+}
+
+// Auth contains authentication parameters that specific Dialers may require.
+type UsernamePassword struct {
+ Username, Password string
+}
+
+type PacketUDP struct {
+ rsv uint16
+ frag uint8
+ atype uint8
+ dstHost string
+ dstPort string
+ data []byte
+}
+
+func NewPacketUDP() (p PacketUDP) {
+ return PacketUDP{}
+}
+func (p *PacketUDP) Build(destAddr string, data []byte) (err error) {
+ host, port, err := net.SplitHostPort(destAddr)
+ if err != nil {
+ return
+ }
+ p.rsv = 0
+ p.frag = 0
+ p.dstHost = host
+ p.dstPort = port
+ p.atype = ATYP_IPV4
+ if ip := net.ParseIP(host); ip != nil {
+ if ip4 := ip.To4(); ip4 != nil {
+ p.atype = ATYP_IPV4
+ ip = ip4
+ } else {
+ p.atype = ATYP_IPV6
+ }
+ } else {
+ if len(host) > 255 {
+ err = errors.New("proxy: destination host name too long: " + host)
+ return
+ }
+ p.atype = ATYP_DOMAIN
+ }
+ p.data = data
+
+ return
+}
+func (p *PacketUDP) Parse(b []byte) (err error) {
+ p.frag = uint8(b[2])
+ if p.frag != 0 {
+ err = fmt.Errorf("FRAG only support for 0 , %v ,%v", p.frag, b[:4])
+ return
+ }
+ portIndex := 0
+ p.atype = b[3]
+ switch p.atype {
+ case ATYP_IPV4: //IP V4
+ p.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String()
+ portIndex = 8
+ case ATYP_DOMAIN: //域名
+ domainLen := uint8(b[4])
+ p.dstHost = string(b[5 : 5+domainLen]) //b[4]表示域名的长度
+ portIndex = int(5 + domainLen)
+ case ATYP_IPV6: //IP V6
+ p.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String()
+ portIndex = 20
+ }
+ p.dstPort = strconv.Itoa(int(b[portIndex])<<8 | int(b[portIndex+1]))
+ p.data = b[portIndex+2:]
+ return
+}
+func (p *PacketUDP) Header() []byte {
+ header := new(bytes.Buffer)
+ header.Write([]byte{0x00, 0x00, p.frag, p.atype})
+ if p.atype == ATYP_IPV4 {
+ ip := net.ParseIP(p.dstHost)
+ header.Write(ip.To4())
+ } else if p.atype == ATYP_IPV6 {
+ ip := net.ParseIP(p.dstHost)
+ header.Write(ip.To16())
+ } else if p.atype == ATYP_DOMAIN {
+ hBytes := []byte(p.dstHost)
+ header.WriteByte(byte(len(hBytes)))
+ header.Write(hBytes)
+ }
+ port, _ := strconv.ParseUint(p.dstPort, 10, 64)
+ portBytes := new(bytes.Buffer)
+ binary.Write(portBytes, binary.BigEndian, port)
+ header.Write(portBytes.Bytes()[portBytes.Len()-2:])
+ return header.Bytes()
+}
+func (p *PacketUDP) Bytes() []byte {
+ packBytes := new(bytes.Buffer)
+ packBytes.Write(p.Header())
+ packBytes.Write(p.data)
+ return packBytes.Bytes()
+}
+func (p *PacketUDP) Host() string {
+ return p.dstHost
+}
+
+func (p *PacketUDP) Port() string {
+ return p.dstPort
+}
+func (p *PacketUDP) Data() []byte {
+ return p.data
+}
diff --git a/core/lib/transport/compress.go b/core/lib/transport/compress.go
new file mode 100644
index 0000000..c26767e
--- /dev/null
+++ b/core/lib/transport/compress.go
@@ -0,0 +1,59 @@
+package transport
+
+import (
+ "net"
+ "time"
+
+ "github.com/golang/snappy"
+)
+
+func NewCompStream(conn net.Conn) *CompStream {
+ c := new(CompStream)
+ c.conn = conn
+ c.w = snappy.NewBufferedWriter(conn)
+ c.r = snappy.NewReader(conn)
+ return c
+}
+func NewCompConn(conn net.Conn) net.Conn {
+ c := CompStream{}
+ c.conn = conn
+ c.w = snappy.NewBufferedWriter(conn)
+ c.r = snappy.NewReader(conn)
+ return &c
+}
+
+type CompStream struct {
+ net.Conn
+ conn net.Conn
+ w *snappy.Writer
+ r *snappy.Reader
+}
+
+func (c *CompStream) Read(p []byte) (n int, err error) {
+ return c.r.Read(p)
+}
+
+func (c *CompStream) Write(p []byte) (n int, err error) {
+ n, err = c.w.Write(p)
+ err = c.w.Flush()
+ return n, err
+}
+
+func (c *CompStream) Close() error {
+ return c.conn.Close()
+}
+func (c *CompStream) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+func (c *CompStream) RemoteAddr() net.Addr {
+ return c.conn.RemoteAddr()
+}
+func (c *CompStream) SetDeadline(t time.Time) error {
+ return c.conn.SetDeadline(t)
+}
+func (c *CompStream) SetReadDeadline(t time.Time) error {
+ return c.conn.SetReadDeadline(t)
+}
+func (c *CompStream) SetWriteDeadline(t time.Time) error {
+ return c.conn.SetWriteDeadline(t)
+}
diff --git a/core/lib/transport/encrypt/conn.go b/core/lib/transport/encrypt/conn.go
new file mode 100644
index 0000000..db35bd8
--- /dev/null
+++ b/core/lib/transport/encrypt/conn.go
@@ -0,0 +1,40 @@
+package encrypt
+
+import (
+ "crypto/cipher"
+ "io"
+ "net"
+
+ lbuf "github.com/snail007/goproxy/core/lib/buf"
+)
+
+var (
+ lBuf = lbuf.NewLeakyBuf(2048, 2048)
+)
+
+type Conn struct {
+ net.Conn
+ *Cipher
+ w io.Writer
+ r io.Reader
+}
+
+func NewConn(c net.Conn, method, password string) (conn net.Conn, err error) {
+ cipher0, err := NewCipher(method, password)
+ if err != nil {
+ return
+ }
+ conn = &Conn{
+ Conn: c,
+ Cipher: cipher0,
+ r: &cipher.StreamReader{S: cipher0.ReadStream, R: c},
+ w: &cipher.StreamWriter{S: cipher0.WriteStream, W: c},
+ }
+ return
+}
+func (s *Conn) Read(b []byte) (n int, err error) {
+ return s.r.Read(b)
+}
+func (s *Conn) Write(b []byte) (n int, err error) {
+ return s.w.Write(b)
+}
diff --git a/core/lib/transport/encrypt/encrypt.go b/core/lib/transport/encrypt/encrypt.go
new file mode 100644
index 0000000..23b3a3f
--- /dev/null
+++ b/core/lib/transport/encrypt/encrypt.go
@@ -0,0 +1,185 @@
+package encrypt
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/des"
+ "crypto/md5"
+ "crypto/rc4"
+ "crypto/sha256"
+ "errors"
+
+ lbuf "github.com/snail007/goproxy/core/lib/buf"
+ "github.com/Yawning/chacha20"
+ "golang.org/x/crypto/blowfish"
+ "golang.org/x/crypto/cast5"
+)
+
+const leakyBufSize = 2048
+const maxNBuf = 2048
+
+var leakyBuf = lbuf.NewLeakyBuf(maxNBuf, leakyBufSize)
+var errEmptyPassword = errors.New("proxy key")
+
+func md5sum(d []byte) []byte {
+ h := md5.New()
+ h.Write(d)
+ return h.Sum(nil)
+}
+
+func evpBytesToKey(password string, keyLen int) (key []byte) {
+ const md5Len = 16
+ cnt := (keyLen-1)/md5Len + 1
+ m := make([]byte, cnt*md5Len)
+ copy(m, md5sum([]byte(password)))
+
+ // Repeatedly call md5 until bytes generated is enough.
+ // Each call to md5 uses data: prev md5 sum + password.
+ d := make([]byte, md5Len+len(password))
+ start := 0
+ for i := 1; i < cnt; i++ {
+ start += md5Len
+ copy(d, m[start-md5Len:start])
+ copy(d[md5Len:], password)
+ copy(m[start:], md5sum(d))
+ }
+ return m[:keyLen]
+}
+
+type DecOrEnc int
+
+const (
+ Decrypt DecOrEnc = iota
+ Encrypt
+)
+
+func newStream(block cipher.Block, err error, key, iv []byte,
+ doe DecOrEnc) (cipher.Stream, error) {
+ if err != nil {
+ return nil, err
+ }
+ if doe == Encrypt {
+ return cipher.NewCFBEncrypter(block, iv), nil
+ } else {
+ return cipher.NewCFBDecrypter(block, iv), nil
+ }
+}
+
+func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := aes.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+ return cipher.NewCTR(block, iv), nil
+}
+
+func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := des.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newBlowFishStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := blowfish.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newCast5Stream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := cast5.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
+ h := md5.New()
+ h.Write(key)
+ h.Write(iv)
+ rc4key := h.Sum(nil)
+
+ return rc4.NewCipher(rc4key)
+}
+
+func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
+ return chacha20.NewCipher(key, iv)
+}
+
+func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
+ return chacha20.NewCipher(key, iv)
+}
+
+type cipherInfo struct {
+ keyLen int
+ ivLen int
+ newStream func(key, iv []byte, doe DecOrEnc) (cipher.Stream, error)
+}
+
+var cipherMethod = map[string]*cipherInfo{
+ "aes-128-cfb": {16, 16, newAESCFBStream},
+ "aes-192-cfb": {24, 16, newAESCFBStream},
+ "aes-256-cfb": {32, 16, newAESCFBStream},
+ "aes-128-ctr": {16, 16, newAESCTRStream},
+ "aes-192-ctr": {24, 16, newAESCTRStream},
+ "aes-256-ctr": {32, 16, newAESCTRStream},
+ "des-cfb": {8, 8, newDESStream},
+ "bf-cfb": {16, 8, newBlowFishStream},
+ "cast5-cfb": {16, 8, newCast5Stream},
+ "rc4-md5": {16, 16, newRC4MD5Stream},
+ "rc4-md5-6": {16, 6, newRC4MD5Stream},
+ "chacha20": {32, 8, newChaCha20Stream},
+ "chacha20-ietf": {32, 12, newChaCha20IETFStream},
+}
+
+func GetCipherMethods() (keys []string) {
+ keys = []string{}
+ for k := range cipherMethod {
+ keys = append(keys, k)
+ }
+ return
+}
+func CheckCipherMethod(method string) error {
+ if method == "" {
+ method = "aes-256-cfb"
+ }
+ _, ok := cipherMethod[method]
+ if !ok {
+ return errors.New("Unsupported encryption method: " + method)
+ }
+ return nil
+}
+
+type Cipher struct {
+ WriteStream cipher.Stream
+ ReadStream cipher.Stream
+ key []byte
+ info *cipherInfo
+}
+
+func NewCipher(method, password string) (c *Cipher, err error) {
+ if password == "" {
+ return nil, errEmptyPassword
+ }
+ mi, ok := cipherMethod[method]
+ if !ok {
+ return nil, errors.New("Unsupported encryption method: " + method)
+ }
+ key := evpBytesToKey(password, mi.keyLen)
+ c = &Cipher{key: key, info: mi}
+ if err != nil {
+ return nil, err
+ }
+ //hash(key) -> read IV
+ riv := sha256.New().Sum(c.key)[:c.info.ivLen]
+ c.ReadStream, err = c.info.newStream(c.key, riv, Decrypt)
+ if err != nil {
+ return nil, err
+ } //hash(read IV) -> write IV
+ wiv := sha256.New().Sum(riv)[:c.info.ivLen]
+ c.WriteStream, err = c.info.newStream(c.key, wiv, Encrypt)
+ if err != nil {
+ return nil, err
+ }
+ return c, nil
+}
diff --git a/core/lib/udp/udp.go b/core/lib/udp/udp.go
new file mode 100644
index 0000000..cf417c5
--- /dev/null
+++ b/core/lib/udp/udp.go
@@ -0,0 +1,234 @@
+package udputils
+
+import (
+ "fmt"
+ logger "log"
+ "net"
+ "runtime/debug"
+ "strings"
+ "time"
+
+ bufx "github.com/snail007/goproxy/core/lib/buf"
+ mapx "github.com/snail007/goproxy/core/lib/mapx"
+)
+
+type CreateOutUDPConnFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, packet []byte) (outconn *net.UDPConn, err error)
+type CleanFn func(srcAddr string)
+type BeforeSendFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, b []byte) (sendB []byte, err error)
+type BeforeReplyFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, outconn *net.UDPConn, b []byte) (replyB []byte, err error)
+
+type IOBinder struct {
+ outConns mapx.ConcurrentMap
+ listener *net.UDPConn
+ createOutUDPConnFn CreateOutUDPConnFn
+ log *logger.Logger
+ timeout time.Duration
+ cleanFn CleanFn
+ inTCPConn *net.Conn
+ outTCPConn *net.Conn
+ beforeSendFn BeforeSendFn
+ beforeReplyFn BeforeReplyFn
+}
+
+func NewIOBinder(listener *net.UDPConn, log *logger.Logger) *IOBinder {
+ return &IOBinder{
+ listener: listener,
+ outConns: mapx.NewConcurrentMap(),
+ log: log,
+ }
+}
+func (s *IOBinder) Factory(fn CreateOutUDPConnFn) *IOBinder {
+ s.createOutUDPConnFn = fn
+ return s
+}
+func (s *IOBinder) AfterReadFromClient(fn BeforeSendFn) *IOBinder {
+ s.beforeSendFn = fn
+ return s
+}
+func (s *IOBinder) AfterReadFromServer(fn BeforeReplyFn) *IOBinder {
+ s.beforeReplyFn = fn
+ return s
+}
+func (s *IOBinder) Timeout(timeout time.Duration) *IOBinder {
+ s.timeout = timeout
+ return s
+}
+func (s *IOBinder) Clean(fn CleanFn) *IOBinder {
+ s.cleanFn = fn
+ return s
+}
+func (s *IOBinder) AliveWithServeConn(srcAddr string, inTCPConn *net.Conn) *IOBinder {
+ s.inTCPConn = inTCPConn
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ buf := make([]byte, 1)
+ (*inTCPConn).SetReadDeadline(time.Time{})
+ if _, err := (*inTCPConn).Read(buf); err != nil {
+ s.log.Printf("udp related tcp conn of client disconnected with read , %s", err.Error())
+ s.clean(srcAddr)
+ }
+ }()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ for {
+ (*inTCPConn).SetWriteDeadline(time.Now().Add(time.Second * 5))
+ if _, err := (*inTCPConn).Write([]byte{0x00}); err != nil {
+ s.log.Printf("udp related tcp conn of client disconnected with write , %s", err.Error())
+ s.clean(srcAddr)
+ return
+ }
+ (*inTCPConn).SetWriteDeadline(time.Time{})
+ time.Sleep(time.Second * 5)
+ }
+ }()
+ return s
+}
+func (s *IOBinder) AliveWithClientConn(srcAddr string, outTCPConn *net.Conn) *IOBinder {
+ s.outTCPConn = outTCPConn
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ buf := make([]byte, 1)
+ (*outTCPConn).SetReadDeadline(time.Time{})
+ if _, err := (*outTCPConn).Read(buf); err != nil {
+ s.log.Printf("udp related tcp conn to parent disconnected with read , %s", err.Error())
+ s.clean(srcAddr)
+ }
+ }()
+ return s
+}
+func (s *IOBinder) Run() (err error) {
+ var (
+ isClosedErr = func(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "use of closed network connection")
+ }
+ isTimeoutErr = func(err error) bool {
+ if err == nil {
+ return false
+ }
+ e, ok := err.(net.Error)
+ return ok && e.Timeout()
+ }
+ isRefusedErr = func(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "connection refused")
+ }
+ )
+ for {
+ buf := bufx.Get()
+ defer bufx.Put(buf)
+ n, srcAddr, err := s.listener.ReadFromUDP(buf)
+ if err != nil {
+ s.log.Printf("read from client error %s", err)
+ if isClosedErr(err) {
+ return err
+ }
+ continue
+ }
+ var data []byte
+ if s.beforeSendFn != nil {
+ data, err = s.beforeSendFn(s.listener, srcAddr, buf[:n])
+ if err != nil {
+ s.log.Printf("beforeSend retured an error , %s", err)
+ continue
+ }
+ } else {
+ data = buf[:n]
+ }
+ inconnRemoteAddr := srcAddr.String()
+ var outconn *net.UDPConn
+ if v, ok := s.outConns.Get(inconnRemoteAddr); !ok {
+ outconn, err = s.createOutUDPConnFn(s.listener, srcAddr, data)
+ if err != nil {
+ s.log.Printf("connnect fail %s", err)
+ return err
+ }
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ defer func() {
+ s.clean(srcAddr.String())
+ }()
+ buf := bufx.Get()
+ defer bufx.Put(buf)
+ for {
+ if s.timeout > 0 {
+ outconn.SetReadDeadline(time.Now().Add(s.timeout))
+ }
+ n, srcAddr, err := outconn.ReadFromUDP(buf)
+ if err != nil {
+ s.log.Printf("read from remote error %s", err)
+ if isClosedErr(err) || isTimeoutErr(err) || isRefusedErr(err) {
+ return
+ }
+ continue
+ }
+ data := buf[:n]
+ if s.beforeReplyFn != nil {
+ data, err = s.beforeReplyFn(s.listener, srcAddr, outconn, buf[:n])
+ if err != nil {
+ s.log.Printf("beforeReply retured an error , %s", err)
+ continue
+ }
+ }
+ _, err = s.listener.WriteTo(data, srcAddr)
+ if err != nil {
+ s.log.Printf("write to remote error %s", err)
+ if isClosedErr(err) {
+ return
+ }
+ continue
+ }
+ }
+ }()
+ } else {
+ outconn = v.(*net.UDPConn)
+ }
+
+ s.log.Printf("use decrpyted data , %v", data)
+
+ _, err = outconn.Write(data)
+
+ if err != nil {
+ s.log.Printf("write to remote error %s", err)
+ if isClosedErr(err) {
+ return err
+ }
+ }
+ }
+}
+func (s *IOBinder) clean(srcAddr string) *IOBinder {
+ if v, ok := s.outConns.Get(srcAddr); ok {
+ (*v.(*net.UDPConn)).Close()
+ s.outConns.Remove(srcAddr)
+ }
+ if s.inTCPConn != nil {
+ (*s.inTCPConn).Close()
+ }
+ if s.outTCPConn != nil {
+ (*s.outTCPConn).Close()
+ }
+ if s.cleanFn != nil {
+ s.cleanFn(srcAddr)
+ }
+ return s
+}
+
+func (s *IOBinder) Close() {
+ for _, c := range s.outConns.Items() {
+ (*c.(*net.UDPConn)).Close()
+ }
+}
diff --git a/core/proxy/client/proxy.go b/core/proxy/client/proxy.go
new file mode 100644
index 0000000..ffc3209
--- /dev/null
+++ b/core/proxy/client/proxy.go
@@ -0,0 +1,31 @@
+// Package proxy provides support for a variety of protocols to proxy network
+// data.
+package client
+
+import (
+ "net"
+ "time"
+
+ socks5c "github.com/snail007/goproxy/core/lib/socks5"
+ socks5 "github.com/snail007/goproxy/core/proxy/client/socks5"
+)
+
+// A Dialer is a means to establish a connection.
+type Dialer interface {
+ // Dial connects to the given address via the proxy.
+ DialConn(conn *net.Conn, network, addr string) (err error)
+}
+
+// Auth contains authentication parameters that specific Dialers may require.
+type Auth struct {
+ User, Password string
+}
+
+func SOCKS5(timeout time.Duration, auth *Auth) (Dialer, error) {
+ var a *socks5c.UsernamePassword
+ if auth != nil {
+ a = &socks5c.UsernamePassword{auth.User, auth.Password}
+ }
+ d := socks5.NewDialer(a, timeout)
+ return d, nil
+}
diff --git a/core/proxy/client/socks5/socks5.go b/core/proxy/client/socks5/socks5.go
new file mode 100644
index 0000000..61f82a3
--- /dev/null
+++ b/core/proxy/client/socks5/socks5.go
@@ -0,0 +1,263 @@
+package socks5
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "time"
+
+ socks5c "github.com/snail007/goproxy/core/lib/socks5"
+)
+
+type Dialer struct {
+ timeout time.Duration
+ usernamePassword *socks5c.UsernamePassword
+}
+
+// NewDialer returns a new Dialer that dials through the provided
+// proxy server's network and address.
+func NewDialer(auth *socks5c.UsernamePassword, timeout time.Duration) *Dialer {
+ if auth != nil && auth.Password == "" && auth.Username == "" {
+ auth = nil
+ }
+ return &Dialer{
+ usernamePassword: auth,
+ timeout: timeout,
+ }
+}
+
+func (d *Dialer) DialConn(conn *net.Conn, network, addr string) (err error) {
+ client := NewClientConn(conn, network, addr, d.timeout, d.usernamePassword, nil)
+ err = client._Handshake()
+ return
+}
+
+type ClientConn struct {
+ user string
+ password string
+ conn *net.Conn
+ header []byte
+ timeout time.Duration
+ addr string
+ network string
+ udpAddr string
+}
+
+// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
+// with an optional username and password. See RFC 1928 and RFC 1929.
+// target must be a canonical address with a host and port.
+// network : tcp udp
+func NewClientConn(conn *net.Conn, network, target string, timeout time.Duration, auth *socks5c.UsernamePassword, header []byte) *ClientConn {
+ s := &ClientConn{
+ conn: conn,
+ network: network,
+ timeout: timeout,
+ }
+ if auth != nil {
+ s.user = auth.Username
+ s.password = auth.Password
+ }
+ if header != nil && len(header) > 0 {
+ s.header = header
+ }
+ if network == "udp" && target == "" {
+ target = "0.0.0.0:1"
+ }
+ s.addr = target
+ return s
+}
+
+// connect takes an existing connection to a socks5 proxy server,
+// and commands the server to extend that connection to target,
+// which must be a canonical address with a host and port.
+func (s *ClientConn) _Handshake() error {
+ host, portStr, err := net.SplitHostPort(s.addr)
+ if err != nil {
+ return err
+ }
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return errors.New("proxy: failed to parse port number: " + portStr)
+ }
+ if port < 1 || port > 0xffff {
+ return errors.New("proxy: port number out of range: " + portStr)
+ }
+
+ if err := s.auth(host); err != nil {
+ return err
+ }
+ buf := []byte{}
+ if s.network == "tcp" {
+ buf = append(buf, socks5c.VERSION_V5, socks5c.CMD_CONNECT, 0 /* reserved */)
+
+ } else {
+ buf = append(buf, socks5c.VERSION_V5, socks5c.CMD_ASSOCIATE, 0 /* reserved */)
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ if ip4 := ip.To4(); ip4 != nil {
+ buf = append(buf, socks5c.ATYP_IPV4)
+ ip = ip4
+ } else {
+ buf = append(buf, socks5c.ATYP_IPV6)
+ }
+ buf = append(buf, ip...)
+ } else {
+ if len(host) > 255 {
+ return errors.New("proxy: destination host name too long: " + host)
+ }
+ buf = append(buf, socks5c.ATYP_DOMAIN)
+ buf = append(buf, byte(len(host)))
+ buf = append(buf, host...)
+ }
+ buf = append(buf, byte(port>>8), byte(port))
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := (*s.conn).Write(buf); err != nil {
+ return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ (*s.conn).SetDeadline(time.Time{})
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := io.ReadFull((*s.conn), buf[:4]); err != nil {
+ return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ (*s.conn).SetDeadline(time.Time{})
+ failure := "unknown error"
+ if int(buf[1]) < len(socks5c.Socks5Errors) {
+ failure = socks5c.Socks5Errors[buf[1]]
+ }
+
+ if len(failure) > 0 {
+ return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
+ }
+
+ bytesToDiscard := 0
+ switch buf[3] {
+ case socks5c.ATYP_IPV4:
+ bytesToDiscard = net.IPv4len
+ case socks5c.ATYP_IPV6:
+ bytesToDiscard = net.IPv6len
+ case socks5c.ATYP_DOMAIN:
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ _, err := io.ReadFull((*s.conn), buf[:1])
+ (*s.conn).SetDeadline(time.Time{})
+ if err != nil {
+ return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ bytesToDiscard = int(buf[0])
+ default:
+ return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
+ }
+
+ if cap(buf) < bytesToDiscard {
+ buf = make([]byte, bytesToDiscard)
+ } else {
+ buf = buf[:bytesToDiscard]
+ }
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := io.ReadFull((*s.conn), buf); err != nil {
+ return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ (*s.conn).SetDeadline(time.Time{})
+ var ip net.IP
+ ip = buf
+ ipStr := ""
+ if bytesToDiscard == net.IPv4len || bytesToDiscard == net.IPv6len {
+ if ipv4 := ip.To4(); ipv4 != nil {
+ ipStr = ipv4.String()
+ } else {
+ ipStr = ip.To16().String()
+ }
+ }
+ //log.Printf("%v", ipStr)
+ // Also need to discard the port number
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil {
+ return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ p := binary.BigEndian.Uint16([]byte{buf[0], buf[1]})
+ //log.Printf("%v", p)
+ s.udpAddr = net.JoinHostPort(ipStr, fmt.Sprintf("%d", p))
+ //log.Printf("%v", s.udpAddr)
+ (*s.conn).SetDeadline(time.Time{})
+ return nil
+}
+func (s *ClientConn) SendUDP(data []byte, addr string) (respData []byte, err error) {
+
+ c, err := net.DialTimeout("udp", s.udpAddr, s.timeout)
+ if err != nil {
+ return
+ }
+ conn := c.(*net.UDPConn)
+
+ p := socks5c.NewPacketUDP()
+ p.Build(addr, data)
+ conn.SetDeadline(time.Now().Add(s.timeout))
+ conn.Write(p.Bytes())
+ conn.SetDeadline(time.Time{})
+
+ buf := make([]byte, 1024)
+ conn.SetDeadline(time.Now().Add(s.timeout))
+ n, _, err := conn.ReadFrom(buf)
+ conn.SetDeadline(time.Time{})
+ if err != nil {
+ return
+ }
+ respData = buf[:n]
+ return
+}
+func (s *ClientConn) auth(host string) error {
+
+ // the size here is just an estimate
+ buf := make([]byte, 0, 6+len(host))
+
+ buf = append(buf, socks5c.VERSION_V5)
+ if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
+ buf = append(buf, 2 /* num auth methods */, socks5c.Method_NO_AUTH, socks5c.Method_USER_PASS)
+ } else {
+ buf = append(buf, 1 /* num auth methods */, socks5c.Method_NO_AUTH)
+ }
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := (*s.conn).Write(buf); err != nil {
+ return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ (*s.conn).SetDeadline(time.Time{})
+
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil {
+ return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ (*s.conn).SetDeadline(time.Time{})
+
+ if buf[0] != 5 {
+ return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
+ }
+ if buf[1] == 0xff {
+ return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
+ }
+
+ // See RFC 1929
+ if buf[1] == socks5c.Method_USER_PASS {
+ buf = buf[:0]
+ buf = append(buf, 1 /* password protocol version */)
+ buf = append(buf, uint8(len(s.user)))
+ buf = append(buf, s.user...)
+ buf = append(buf, uint8(len(s.password)))
+ buf = append(buf, s.password...)
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := (*s.conn).Write(buf); err != nil {
+ return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ (*s.conn).SetDeadline(time.Time{})
+ (*s.conn).SetDeadline(time.Now().Add(s.timeout))
+ if _, err := io.ReadFull((*s.conn), buf[:2]); err != nil {
+ return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+ }
+ (*s.conn).SetDeadline(time.Time{})
+ if buf[1] != 0 {
+ return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
+ }
+ }
+ return nil
+}
diff --git a/core/proxy/client/tests/proxy_test.go b/core/proxy/client/tests/proxy_test.go
new file mode 100644
index 0000000..28f8645
--- /dev/null
+++ b/core/proxy/client/tests/proxy_test.go
@@ -0,0 +1,79 @@
+package tests
+
+import (
+ "io/ioutil"
+ "net"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ proxyclient "github.com/snail007/goproxy/core/proxy/client"
+ sdk "github.com/snail007/goproxy/sdk/android-ios"
+)
+
+func TestSocks5(t *testing.T) {
+ estr := sdk.Start("s1", "socks -p :8185 --log test.log")
+ if estr != "" {
+ t.Fatal(estr)
+ }
+ p, e := proxyclient.SOCKS5(time.Second, nil)
+ if e != nil {
+ t.Error(e)
+ } else {
+ c, e := net.Dial("tcp", "127.0.0.1:8185")
+ if e != nil {
+ t.Fatal(e)
+ }
+ e = p.DialConn(&c, "tcp", "www.baidu.com:80")
+ if e != nil {
+ t.Fatal(e)
+ }
+ _, e = c.Write([]byte("Get / http/1.1\r\nHost: www.baidu.com\r\n"))
+ if e != nil {
+ t.Fatal(e)
+ }
+ b, e := ioutil.ReadAll(c)
+ if e != nil {
+ t.Fatal(e)
+ }
+ if !strings.HasPrefix(string(b), "HTTP") {
+ t.Fatalf("request baidu fail:%s", string(b))
+ }
+ }
+ sdk.Stop("s1")
+ os.Remove("test.log")
+}
+
+func TestSocks5Auth(t *testing.T) {
+ estr := sdk.Start("s1", "socks -p :8185 -a u:p --log test.log")
+ if estr != "" {
+ t.Fatal(estr)
+ }
+ p, e := proxyclient.SOCKS5(time.Second, &proxyclient.Auth{User: "u", Password: "p"})
+ if e != nil {
+ t.Error(e)
+ } else {
+ c, e := net.Dial("tcp", "127.0.0.1:8185")
+ if e != nil {
+ t.Fatal(e)
+ }
+ e = p.DialConn(&c, "tcp", "www.baidu.com:80")
+ if e != nil {
+ t.Fatal(e)
+ }
+ _, e = c.Write([]byte("Get / http/1.1\r\nHost: www.baidu.com\r\n"))
+ if e != nil {
+ t.Fatal(e)
+ }
+ b, e := ioutil.ReadAll(c)
+ if e != nil {
+ t.Fatal(e)
+ }
+ if !strings.HasPrefix(string(b), "HTTP") {
+ t.Fatalf("request baidu fail:%s", string(b))
+ }
+ }
+ sdk.Stop("s1")
+ os.Remove("test.log")
+}
diff --git a/core/proxy/server/socks5/server.go b/core/proxy/server/socks5/server.go
new file mode 100644
index 0000000..fff8d03
--- /dev/null
+++ b/core/proxy/server/socks5/server.go
@@ -0,0 +1,373 @@
+package socks5
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ socks5c "github.com/snail007/goproxy/core/proxy/common/socks5"
+)
+
+type BasicAuther interface {
+ CheckUserPass(username, password, fromIP, ToTarget string) bool
+}
+type Request struct {
+ ver uint8
+ cmd uint8
+ reserve uint8
+ addressType uint8
+ dstAddr string
+ dstPort string
+ dstHost string
+ bytes []byte
+ rw io.ReadWriter
+}
+
+func NewRequest(rw io.ReadWriter, header ...[]byte) (req Request, err interface{}) {
+ var b = make([]byte, 1024)
+ var n int
+ req = Request{rw: rw}
+ if header != nil && len(header) == 1 && len(header[0]) > 1 {
+ b = header[0]
+ n = len(header[0])
+ } else {
+ n, err = rw.Read(b[:])
+ if err != nil {
+ err = fmt.Errorf("read req data fail,ERR: %s", err)
+ return
+ }
+ }
+ req.ver = uint8(b[0])
+ req.cmd = uint8(b[1])
+ req.reserve = uint8(b[2])
+ req.addressType = uint8(b[3])
+ if b[0] != 0x5 {
+ err = fmt.Errorf("sosck version supported")
+ req.TCPReply(socks5c.REP_REQ_FAIL)
+ return
+ }
+ switch b[3] {
+ case 0x01: //IP V4
+ req.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String()
+ case 0x03: //域名
+ req.dstHost = string(b[5 : n-2]) //b[4]表示域名的长度
+ case 0x04: //IP V6
+ req.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String()
+ }
+ req.dstPort = strconv.Itoa(int(b[n-2])<<8 | int(b[n-1]))
+ req.dstAddr = net.JoinHostPort(req.dstHost, req.dstPort)
+ req.bytes = b[:n]
+ return
+}
+func (s *Request) Bytes() []byte {
+ return s.bytes
+}
+func (s *Request) Addr() string {
+ return s.dstAddr
+}
+func (s *Request) Host() string {
+ return s.dstHost
+}
+func (s *Request) Port() string {
+ return s.dstPort
+}
+func (s *Request) AType() uint8 {
+ return s.addressType
+}
+func (s *Request) CMD() uint8 {
+ return s.cmd
+}
+
+func (s *Request) TCPReply(rep uint8) (err error) {
+ _, err = s.rw.Write(s.NewReply(rep, "0.0.0.0:0"))
+ return
+}
+func (s *Request) UDPReply(rep uint8, addr string) (err error) {
+ _, err = s.rw.Write(s.NewReply(rep, addr))
+ return
+}
+func (s *Request) NewReply(rep uint8, addr string) []byte {
+ var response bytes.Buffer
+ host, port, _ := net.SplitHostPort(addr)
+ ip := net.ParseIP(host)
+ ipb := ip.To4()
+ atyp := socks5c.ATYP_IPV4
+ ipv6 := ip.To16()
+ zeroiIPv6 := fmt.Sprintf("%d%d%d%d%d%d%d%d%d%d%d%d",
+ ipv6[0], ipv6[1], ipv6[2], ipv6[3],
+ ipv6[4], ipv6[5], ipv6[6], ipv6[7],
+ ipv6[8], ipv6[9], ipv6[10], ipv6[11],
+ )
+ if ipb == nil && ipv6 != nil && "0000000000255255" != zeroiIPv6 {
+ atyp = socks5c.ATYP_IPV6
+ ipb = ip.To16()
+ }
+ porti, _ := strconv.Atoi(port)
+ portb := make([]byte, 2)
+ binary.BigEndian.PutUint16(portb, uint16(porti))
+ // log.Printf("atyp : %v", atyp)
+ // log.Printf("ip : %v", []byte(ip))
+ response.WriteByte(socks5c.VERSION_V5)
+ response.WriteByte(rep)
+ response.WriteByte(socks5c.RSV)
+ response.WriteByte(atyp)
+ response.Write(ipb)
+ response.Write(portb)
+ return response.Bytes()
+}
+
+type MethodsRequest struct {
+ ver uint8
+ methodsCount uint8
+ methods []uint8
+ bytes []byte
+ rw *io.ReadWriter
+}
+
+func NewMethodsRequest(r io.ReadWriter, header ...[]byte) (s MethodsRequest, err interface{}) {
+ defer func() {
+ if err == nil {
+ err = recover()
+ }
+ }()
+ s = MethodsRequest{}
+ s.rw = &r
+ var buf = make([]byte, 300)
+ var n int
+ if header != nil && len(header) == 1 && len(header[0]) > 1 {
+ buf = header[0]
+ n = len(header[0])
+ } else {
+ n, err = r.Read(buf)
+ if err != nil {
+ return
+ }
+ }
+ if buf[0] != 0x05 {
+ err = fmt.Errorf("socks version not supported")
+ return
+ }
+ if n != int(buf[1])+int(2) {
+ err = fmt.Errorf("socks methods data length error")
+ return
+ }
+ s.ver = buf[0]
+ s.methodsCount = buf[1]
+ s.methods = buf[2:n]
+ s.bytes = buf[:n]
+ return
+}
+func (s *MethodsRequest) Version() uint8 {
+ return s.ver
+}
+func (s *MethodsRequest) MethodsCount() uint8 {
+ return s.methodsCount
+}
+func (s *MethodsRequest) Methods() []uint8 {
+ return s.methods
+}
+func (s *MethodsRequest) Select(method uint8) bool {
+ for _, m := range s.methods {
+ if m == method {
+ return true
+ }
+ }
+ return false
+}
+func (s *MethodsRequest) Reply(method uint8) (err error) {
+ _, err = (*s.rw).Write([]byte{byte(socks5c.VERSION_V5), byte(method)})
+ return
+}
+func (s *MethodsRequest) Bytes() []byte {
+ return s.bytes
+}
+
+type ServerConn struct {
+ target string
+ user string
+ password string
+ conn *net.Conn
+ timeout time.Duration
+ auth *BasicAuther
+ header []byte
+ ver uint8
+ //method
+ methodsCount uint8
+ methods []uint8
+ method uint8
+ //request
+ cmd uint8
+ reserve uint8
+ addressType uint8
+ dstAddr string
+ dstPort string
+ dstHost string
+ udpAddress string
+}
+
+func NewServerConn(conn *net.Conn, timeout time.Duration, auth *BasicAuther, udpAddress string, header []byte) *ServerConn {
+ if udpAddress == "" {
+ udpAddress = "0.0.0.0:16666"
+ }
+ s := &ServerConn{
+ conn: conn,
+ timeout: timeout,
+ auth: auth,
+ header: header,
+ ver: socks5c.VERSION_V5,
+ udpAddress: udpAddress,
+ }
+ return s
+
+}
+func (s *ServerConn) Close() {
+ (*s.conn).Close()
+}
+func (s *ServerConn) AuthData() socks5c.UsernamePassword {
+ return socks5c.UsernamePassword{s.user, s.password}
+}
+func (s *ServerConn) Method() uint8 {
+ return s.method
+}
+func (s *ServerConn) Target() string {
+ return s.target
+}
+func (s *ServerConn) Handshake() (err error) {
+ remoteAddr := (*s.conn).RemoteAddr()
+ //协商开始
+ //method select request
+ var methodReq MethodsRequest
+ (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+
+ methodReq, e := NewMethodsRequest((*s.conn), s.header)
+ (*s.conn).SetReadDeadline(time.Time{})
+ if e != nil {
+ (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+ methodReq.Reply(socks5c.Method_NONE_ACCEPTABLE)
+ (*s.conn).SetReadDeadline(time.Time{})
+ err = fmt.Errorf("new methods request fail,ERR: %s", e)
+ return
+ }
+ //log.Printf("%v,s.auth == %v && methodReq.Select(Method_NO_AUTH) %v", methodReq.methods, s.auth, methodReq.Select(Method_NO_AUTH))
+ if s.auth == nil && methodReq.Select(socks5c.Method_NO_AUTH) && !methodReq.Select(socks5c.Method_USER_PASS) {
+ // if !methodReq.Select(Method_NO_AUTH) {
+ // (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+ // methodReq.Reply(Method_NONE_ACCEPTABLE)
+ // (*s.conn).SetReadDeadline(time.Time{})
+ // err = fmt.Errorf("none method found : Method_NO_AUTH")
+ // return
+ // }
+ s.method = socks5c.Method_NO_AUTH
+ //method select reply
+ (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+ err = methodReq.Reply(socks5c.Method_NO_AUTH)
+ (*s.conn).SetReadDeadline(time.Time{})
+ if err != nil {
+ err = fmt.Errorf("reply answer data fail,ERR: %s", err)
+ return
+ }
+ // err = fmt.Errorf("% x", methodReq.Bytes())
+ } else {
+ //auth
+ if !methodReq.Select(socks5c.Method_USER_PASS) {
+ (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+ methodReq.Reply(socks5c.Method_NONE_ACCEPTABLE)
+ (*s.conn).SetReadDeadline(time.Time{})
+ err = fmt.Errorf("none method found : Method_USER_PASS")
+ return
+ }
+ s.method = socks5c.Method_USER_PASS
+ //method reply need auth
+ (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+ err = methodReq.Reply(socks5c.Method_USER_PASS)
+ (*s.conn).SetReadDeadline(time.Time{})
+ if err != nil {
+ err = fmt.Errorf("reply answer data fail,ERR: %s", err)
+ return
+ }
+ //read auth
+ buf := make([]byte, 500)
+ var n int
+ (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+ n, err = (*s.conn).Read(buf)
+ (*s.conn).SetReadDeadline(time.Time{})
+ if err != nil {
+ err = fmt.Errorf("read auth info fail,ERR: %s", err)
+ return
+ }
+ r := buf[:n]
+ s.user = string(r[2 : r[1]+2])
+ s.password = string(r[2+r[1]+1:])
+ //err = fmt.Errorf("user:%s,pass:%s", user, pass)
+ //auth
+ _addr := strings.Split(remoteAddr.String(), ":")
+ if s.auth == nil || (*s.auth).CheckUserPass(s.user, s.password, _addr[0], "") {
+ (*s.conn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(s.timeout)))
+ _, err = (*s.conn).Write([]byte{0x01, 0x00})
+ (*s.conn).SetDeadline(time.Time{})
+ if err != nil {
+ err = fmt.Errorf("answer auth success to %s fail,ERR: %s", remoteAddr, err)
+ return
+ }
+ } else {
+ (*s.conn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(s.timeout)))
+ _, err = (*s.conn).Write([]byte{0x01, 0x01})
+ (*s.conn).SetDeadline(time.Time{})
+ if err != nil {
+ err = fmt.Errorf("answer auth fail to %s fail,ERR: %s", remoteAddr, err)
+ return
+ }
+ err = fmt.Errorf("auth fail from %s", remoteAddr)
+ return
+ }
+ }
+ //request detail
+ (*s.conn).SetReadDeadline(time.Now().Add(time.Second * s.timeout))
+ request, e := NewRequest(*s.conn)
+ (*s.conn).SetReadDeadline(time.Time{})
+ if e != nil {
+ err = fmt.Errorf("read request data fail,ERR: %s", e)
+ return
+ }
+ //协商结束
+
+ switch request.CMD() {
+ case socks5c.CMD_BIND:
+ err = request.TCPReply(socks5c.REP_UNKNOWN)
+ if err != nil {
+ err = fmt.Errorf("TCPReply REP_UNKNOWN to %s fail,ERR: %s", remoteAddr, err)
+ return
+ }
+ err = fmt.Errorf("cmd bind not supported, form: %s", remoteAddr)
+ return
+ case socks5c.CMD_CONNECT:
+ err = request.TCPReply(socks5c.REP_SUCCESS)
+ if err != nil {
+ err = fmt.Errorf("TCPReply REP_SUCCESS to %s fail,ERR: %s", remoteAddr, err)
+ return
+ }
+ case socks5c.CMD_ASSOCIATE:
+ err = request.UDPReply(socks5c.REP_SUCCESS, s.udpAddress)
+ if err != nil {
+ err = fmt.Errorf("UDPReply REP_SUCCESS to %s fail,ERR: %s", remoteAddr, err)
+ return
+ }
+ }
+
+ //fill socks info
+ s.target = request.Addr()
+ s.methodsCount = methodReq.MethodsCount()
+ s.methods = methodReq.Methods()
+ s.cmd = request.CMD()
+ s.reserve = request.reserve
+ s.addressType = request.addressType
+ s.dstAddr = request.dstAddr
+ s.dstHost = request.dstHost
+ s.dstPort = request.dstPort
+ return
+}
diff --git a/core/tproxy/README.md b/core/tproxy/README.md
new file mode 100644
index 0000000..13de09b
--- /dev/null
+++ b/core/tproxy/README.md
@@ -0,0 +1,35 @@
+# 透传用户IP手册
+
+说明:
+
+通过Linux的TPROXY功能,可以实现源站服务程序可以看见客户端真实IP,实现该功能需要linux操作系统和程序都要满足一定的条件.
+
+环境要求:
+
+源站必须是运行在Linux上面的服务程序,同时Linux需要满足下面条件:
+
+1.Linux内核版本 >= 2.6.28
+
+2.判断系统是否支持TPROXY,执行:
+
+ grep TPROXY /boot/config-`uname -r`
+
+ 如果输出有下面的结果说明支持.
+
+ CONFIG_NETFILTER_XT_TARGET_TPROXY=m
+
+部署步骤:
+
+1.在源站的linux系统里面每次开机启动都要用root权限执行tproxy环境设置脚本:tproxy_setup.sh
+
+2.在源站的linux系统里面使用root权限执行代理proxy
+
+参数 -tproxy 是开启代理的tproxy功能.
+
+./proxy -tproxy
+
+2.源站的程序监听的地址IP需要使用:127.0.1.1
+
+比如源站以前监听的地址是: 0.0.0.0:8800 , 现在需要修改为:127.0.1.1:8800
+
+3.转发规则里面源站地址必须是对应的,比如上面的:127.0.1.1:8800
diff --git a/core/tproxy/tproxy.go b/core/tproxy/tproxy.go
new file mode 100644
index 0000000..ca754c2
--- /dev/null
+++ b/core/tproxy/tproxy.go
@@ -0,0 +1,249 @@
+// Package tproxy provides the TCPDial and TCPListen tproxy equivalent of the
+// net package Dial and Listen with tproxy support for linux ONLY.
+package tproxy
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "time"
+
+ "golang.org/x/sys/unix"
+)
+
+const big = 0xFFFFFF
+const IP_ORIGADDRS = 20
+
+// Debug outs the library in Debug mode
+var Debug = false
+
+func ipToSocksAddr(family int, ip net.IP, port int, zone string) (unix.Sockaddr, error) {
+ switch family {
+ case unix.AF_INET:
+ if len(ip) == 0 {
+ ip = net.IPv4zero
+ }
+ if ip = ip.To4(); ip == nil {
+ return nil, net.InvalidAddrError("non-IPv4 address")
+ }
+ sa := new(unix.SockaddrInet4)
+ for i := 0; i < net.IPv4len; i++ {
+ sa.Addr[i] = ip[i]
+ }
+ sa.Port = port
+ return sa, nil
+ case unix.AF_INET6:
+ if len(ip) == 0 {
+ ip = net.IPv6zero
+ }
+ // IPv4 callers use 0.0.0.0 to mean "announce on any available address".
+ // In IPv6 mode, Linux treats that as meaning "announce on 0.0.0.0",
+ // which it refuses to do. Rewrite to the IPv6 unspecified address.
+ if ip.Equal(net.IPv4zero) {
+ ip = net.IPv6zero
+ }
+ if ip = ip.To16(); ip == nil {
+ return nil, net.InvalidAddrError("non-IPv6 address")
+ }
+ sa := new(unix.SockaddrInet6)
+ for i := 0; i < net.IPv6len; i++ {
+ sa.Addr[i] = ip[i]
+ }
+ sa.Port = port
+ sa.ZoneId = uint32(zoneToInt(zone))
+ return sa, nil
+ }
+ return nil, net.InvalidAddrError("unexpected socket family")
+}
+
+func zoneToInt(zone string) int {
+ if zone == "" {
+ return 0
+ }
+ if ifi, err := net.InterfaceByName(zone); err == nil {
+ return ifi.Index
+ }
+ n, _, _ := dtoi(zone, 0)
+ return n
+}
+
+func dtoi(s string, i0 int) (n int, i int, ok bool) {
+ n = 0
+ for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
+ n = n*10 + int(s[i]-'0')
+ if n >= big {
+ return 0, i, false
+ }
+ }
+ if i == i0 {
+ return 0, i, false
+ }
+ return n, i, true
+}
+
+// IPTcpAddrToUnixSocksAddr ---
+func IPTcpAddrToUnixSocksAddr(addr string) (sa unix.Sockaddr, err error) {
+ if Debug {
+ fmt.Println("DEBUG: IPTcpAddrToUnixSocksAddr recieved address:", addr)
+ }
+ addressNet := "tcp6"
+ if addr[0] != '[' {
+ addressNet = "tcp4"
+ }
+ tcpAddr, err := net.ResolveTCPAddr(addressNet, addr)
+ if err != nil {
+ return nil, err
+ }
+ return ipToSocksAddr(ipType(addr), tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone)
+}
+
+// IPv6UdpAddrToUnixSocksAddr ---
+func IPv6UdpAddrToUnixSocksAddr(addr string) (sa unix.Sockaddr, err error) {
+ tcpAddr, err := net.ResolveTCPAddr("udp6", addr)
+ if err != nil {
+ return nil, err
+ }
+ return ipToSocksAddr(unix.AF_INET6, tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone)
+}
+
+// TCPListen is listening for incoming IP packets which are being intercepted.
+// In conflict to regular Listen mehtod the socket destination and source addresses
+// are of the intercepted connection.
+// Else then that it works exactly like net package net.Listen.
+func TCPListen(listenAddr string) (listener net.Listener, err error) {
+ s, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer unix.Close(s)
+ err = unix.SetsockoptInt(s, unix.SOL_IP, unix.IP_TRANSPARENT, 1)
+ if err != nil {
+ return nil, err
+ }
+
+ sa, err := IPTcpAddrToUnixSocksAddr(listenAddr)
+ if err != nil {
+ return nil, err
+ }
+ err = unix.Bind(s, sa)
+ if err != nil {
+ return nil, err
+ }
+ err = unix.Listen(s, unix.SOMAXCONN)
+ if err != nil {
+ return nil, err
+ }
+ f := os.NewFile(uintptr(s), "TProxy")
+ defer f.Close()
+ return net.FileListener(f)
+}
+func ipType(localAddr string) int {
+ host, _, _ := net.SplitHostPort(localAddr)
+ if host != "" {
+ ip := net.ParseIP(host)
+ if ip == nil || ip.To4() != nil {
+ return unix.AF_INET
+ }
+ return unix.AF_INET6
+ }
+ return unix.AF_INET
+}
+
+// TCPDial is a special tcp connection which binds a non local address as the source.
+// Except then the option to bind to a specific local address which the machine doesn't posses
+// it is exactly like any other net.Conn connection.
+// It is advised to use port numbered 0 in the localAddr and leave the kernel to choose which
+// Local port to use in order to avoid errors and binding conflicts.
+func TCPDial(localAddr, remoteAddr string, timeout time.Duration) (conn net.Conn, err error) {
+ timer := time.NewTimer(timeout)
+ defer timer.Stop()
+ if Debug {
+ fmt.Println("TCPDial from:", localAddr, "to:", remoteAddr)
+ }
+ s, err := unix.Socket(ipType(localAddr), unix.SOCK_STREAM, 0)
+
+ //In a case there was a need for a non-blocking socket an example
+ //s, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM |unix.SOCK_NONBLOCK, 0)
+ if err != nil {
+ fmt.Println(err)
+ return nil, err
+ }
+ defer unix.Close(s)
+ err = unix.SetsockoptInt(s, unix.SOL_IP, unix.IP_TRANSPARENT, 1)
+ if err != nil {
+ if Debug {
+ fmt.Println("ERROR setting the socket in IP_TRANSPARENT mode", err)
+ }
+
+ return nil, err
+ }
+ err = unix.SetsockoptInt(s, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
+ if err != nil {
+ if Debug {
+ fmt.Println("ERROR setting the socket in unix.SO_REUSEADDR mode", err)
+ }
+ return nil, err
+ }
+
+ rhost, _, err := net.SplitHostPort(localAddr)
+ if err != nil {
+ if Debug {
+ // fmt.Fprintln(os.Stderr, err)
+ fmt.Println("ERROR", err, "running net.SplitHostPort on address:", localAddr)
+ }
+ }
+
+ sa, err := IPTcpAddrToUnixSocksAddr(rhost + ":0")
+ if err != nil {
+ if Debug {
+ fmt.Println("ERROR creating a hostaddres for the socker with IPTcpAddrToUnixSocksAddr", err)
+ }
+ return nil, err
+ }
+
+ remoteSocket, err := IPTcpAddrToUnixSocksAddr(remoteAddr)
+ if err != nil {
+ if Debug {
+ fmt.Println("ERROR creating a remoteSocket for the socker with IPTcpAddrToUnixSocksAddr on the remote addres", err)
+ }
+ return nil, err
+ }
+
+ err = unix.Bind(s, sa)
+ if err != nil {
+ fmt.Println(err)
+ return nil, err
+ }
+
+ errChn := make(chan error, 1)
+ func() {
+ err = unix.Connect(s, remoteSocket)
+ if err != nil {
+ if Debug {
+ fmt.Println("ERROR Connecting from", s, "to:", remoteSocket, "ERROR:", err)
+ }
+ }
+ errChn <- err
+ }()
+
+ select {
+ case err = <-errChn:
+ if err != nil {
+ return nil, err
+ }
+ case <-timer.C:
+ return nil, fmt.Errorf("ERROR connect to %s timeout", remoteAddr)
+ }
+ f := os.NewFile(uintptr(s), "TProxyTCPClient")
+ client, err := net.FileConn(f)
+ if err != nil {
+ if Debug {
+ fmt.Println("ERROR os.NewFile", err)
+ }
+ return nil, err
+ }
+ if Debug {
+ fmt.Println("FINISHED Creating net.coo from:", client.LocalAddr().String(), "to:", client.RemoteAddr().String())
+ }
+ return client, err
+}
diff --git a/core/tproxy/tproxy_setup.sh b/core/tproxy/tproxy_setup.sh
new file mode 100644
index 0000000..c5a6fd4
--- /dev/null
+++ b/core/tproxy/tproxy_setup.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+SOURCE_BIND_IP="127.0.1.1"
+
+echo 0 > /proc/sys/net/ipv4/conf/lo/rp_filter
+echo 2 > /proc/sys/net/ipv4/conf/default/rp_filter
+echo 2 > /proc/sys/net/ipv4/conf/all/rp_filter
+echo 1 > /proc/sys/net/ipv4/conf/all/send_redirects
+echo 1 > /proc/sys/net/ipv4/conf/all/forwarding
+echo 1 > /proc/sys/net/ipv4/ip_forward
+
+# 本地的话,貌似这段不需要
+# iptables -t mangle -N DIVERT >/dev/null 2>&1
+# iptables -t mangle -F DIVERT
+# iptables -t mangle -D PREROUTING -p tcp -m socket -j DIVERT >/dev/null 2>&1
+# iptables -t mangle -A PREROUTING -p tcp -m socket -j DIVERT
+# iptables -t mangle -A DIVERT -j MARK --set-mark 1
+# iptables -t mangle -A DIVERT -j ACCEPT
+
+ip rule del fwmark 1 lookup 100
+ip rule add fwmark 1 lookup 100
+ip route del local 0.0.0.0/0 dev lo table 100
+ip route add local 0.0.0.0/0 dev lo table 100
+
+ip rule del from ${SOURCE_BIND_IP} table 101
+ip rule add from ${SOURCE_BIND_IP} table 101
+ip route del default via 127.0.0.1 dev lo table 101
+ip route add default via 127.0.0.1 dev lo table 101
+
+ip route flush cache
+ip ro flush cache
\ No newline at end of file
diff --git a/install_auto.sh b/install_auto.sh
index 295a483..32bf059 100755
--- a/install_auto.sh
+++ b/install_auto.sh
@@ -5,7 +5,7 @@ if [ -e /tmp/proxy ]; then
fi
mkdir /tmp/proxy
cd /tmp/proxy
-wget https://github.com/snail007/goproxy/releases/download/v5.4/proxy-linux-amd64.tar.gz
+wget https://github.com/snail007/goproxy/releases/download/v6.0/proxy-linux-amd64.tar.gz
# #install proxy
tar zxvf proxy-linux-amd64.tar.gz
diff --git a/main.go b/main.go
index 9d0cd56..84345f6 100644
--- a/main.go
+++ b/main.go
@@ -1,15 +1,17 @@
package main
import (
+ "fmt"
"log"
"os"
"os/signal"
+ "runtime/debug"
"syscall"
"github.com/snail007/goproxy/services"
)
-const APP_VERSION = "5.4"
+var APP_VERSION = "No Version Provided"
func main() {
err := initConfig()
@@ -32,6 +34,11 @@ func Clean(s *services.Service) {
syscall.SIGTERM,
syscall.SIGQUIT)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
for _ = range signalChan {
log.Println("Received an interrupt, stopping services...")
if s != nil && *s != nil {
diff --git a/release.sh b/release.sh
index 4f11736..b6c2217 100755
--- a/release.sh
+++ b/release.sh
@@ -1,74 +1,87 @@
#!/bin/bash
-VER="5.4"
-RELEASE="release-${VER}"
+VERSION=$(cat VERSION)
+VER="${VERSION}_$(date '+%Y%m%d%H%M%S')"
+X="-X main.APP_VERSION=$VER"
+RELEASE="release-${VERSION}"
+TRIMPATH1="/Users/snail/go/src/github.com/snail007"
+TRIMPATH=$(dirname ~/go/src/github.com/snail007)/snail007
+if [ -d "$TRIMPATH1" ];then
+ TRIMPATH=$TRIMPATH1
+fi
+OPTS="-gcflags=-trimpath=$TRIMPATH -asmflags=-trimpath=$TRIMPATH"
+
rm -rf .cert
mkdir .cert
-go build -o proxy
+go build $OPTS -ldflags "$X" -o proxy
cd .cert
../proxy keygen -C proxy
cd ..
rm -rf ${RELEASE}
mkdir ${RELEASE}
#linux
-CGO_ENABLED=0 GOOS=linux GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-amd64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm-v6.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=6 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm64-v6.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm-v7.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm64-v7.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=5 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm-v5.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=5 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm64-v5.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=mips go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=mips64 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=mips64le go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips64le.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=mipsle go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-mipsle.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=ppc64 go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-ppc64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=ppc64le go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-ppc64le.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=linux GOARCH=s390x go build -o proxy && tar zcfv "${RELEASE}/proxy-linux-s390x.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm-v6.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=6 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm64-v6.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm-v7.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm64-v7.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=5 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm-v5.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=5 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm64-v5.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm64-v8.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-arm-v8.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mips go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mips64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mips64le go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips64le.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mipsle go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mipsle.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mips GOMIPS=softfloat go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips-softfloat.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mips64 GOMIPS=softfloat go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips64-softfloat.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mips64le GOMIPS=softfloat go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mips64le-softfloat.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=mipsle GOMIPS=softfloat go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-mipsle-softfloat.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=ppc64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-ppc64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=ppc64le go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-ppc64le.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=linux GOARCH=s390x go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-linux-s390x.tar.gz" proxy direct blocked
#android
-CGO_ENABLED=0 GOOS=android GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-android-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=android GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-android-amd64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=android GOARCH=arm go build -o proxy && tar zcfv "${RELEASE}/proxy-android-arm.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=android GOARCH=arm64 go build -o proxy && tar zcfv "${RELEASE}/proxy-android-arm64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=android GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-android-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=android GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-android-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=android GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-android-arm.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=android GOARCH=arm64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-android-arm64.tar.gz" proxy direct blocked
#darwin
-CGO_ENABLED=0 GOOS=darwin GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-darwin-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-darwin-amd64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=darwin GOARCH=arm go build -o proxy && tar zcfv "${RELEASE}/proxy-darwin-arm.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o proxy && tar zcfv "${RELEASE}/proxy-darwin-arm64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=darwin GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-darwin-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-darwin-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=darwin GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-darwin-arm.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-darwin-arm64.tar.gz" proxy direct blocked
#dragonfly
-CGO_ENABLED=0 GOOS=dragonfly GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-dragonfly-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=dragonfly GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-dragonfly-amd64.tar.gz" proxy direct blocked
#freebsd
-CGO_ENABLED=0 GOOS=freebsd GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-freebsd-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-freebsd-amd64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=freebsd GOARCH=arm go build -o proxy && tar zcfv "${RELEASE}/proxy-freebsd-arm.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=freebsd GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-freebsd-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-freebsd-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=freebsd GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-freebsd-arm.tar.gz" proxy direct blocked
#nacl
-CGO_ENABLED=0 GOOS=nacl GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-nacl-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=nacl GOARCH=amd64p32 go build -o proxy && tar zcfv "${RELEASE}/proxy-nacl-amd64p32.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=nacl GOARCH=arm go build -o proxy && tar zcfv "${RELEASE}/proxy-nacl-arm.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=nacl GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-nacl-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=nacl GOARCH=amd64p32 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-nacl-amd64p32.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=nacl GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-nacl-arm.tar.gz" proxy direct blocked
#netbsd
-CGO_ENABLED=0 GOOS=netbsd GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-netbsd-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=netbsd GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-netbsd-amd64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=netbsd GOARCH=arm go build -o proxy && tar zcfv "${RELEASE}/proxy-netbsd-arm.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=netbsd GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-netbsd-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=netbsd GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-netbsd-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=netbsd GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-netbsd-arm.tar.gz" proxy direct blocked
#openbsd
-CGO_ENABLED=0 GOOS=openbsd GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-openbsd-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=openbsd GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-openbsd-amd64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=openbsd GOARCH=arm go build -o proxy && tar zcfv "${RELEASE}/proxy-openbsd-arm.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=openbsd GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-openbsd-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=openbsd GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-openbsd-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=openbsd GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-openbsd-arm.tar.gz" proxy direct blocked
#plan9
-CGO_ENABLED=0 GOOS=plan9 GOARCH=386 go build -o proxy && tar zcfv "${RELEASE}/proxy-plan9-386.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=plan9 GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-plan9-amd64.tar.gz" proxy direct blocked
-CGO_ENABLED=0 GOOS=plan9 GOARCH=arm go build -o proxy && tar zcfv "${RELEASE}/proxy-plan9-arm.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=plan9 GOARCH=386 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-plan9-386.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=plan9 GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-plan9-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=plan9 GOARCH=arm go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-plan9-arm.tar.gz" proxy direct blocked
#solaris
-CGO_ENABLED=0 GOOS=solaris GOARCH=amd64 go build -o proxy && tar zcfv "${RELEASE}/proxy-solaris-amd64.tar.gz" proxy direct blocked
+CGO_ENABLED=0 GOOS=solaris GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy && tar zcfv "${RELEASE}/proxy-solaris-amd64.tar.gz" proxy direct blocked
#windows
-CGO_ENABLED=0 GOOS=windows GOARCH=386 go build -ldflags="-H=windowsgui" -o proxy-noconsole.exe
-CGO_ENABLED=0 GOOS=windows GOARCH=386 go build -o proxy.exe && tar zcfv "${RELEASE}/proxy-windows-386.tar.gz" proxy.exe proxy-noconsole.exe direct blocked .cert/proxy.crt .cert/proxy.key
-CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -ldflags="-H=windowsgui" -o proxy-noconsole.exe
-CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o proxy.exe && tar zcfv "${RELEASE}/proxy-windows-amd64.tar.gz" proxy.exe proxy-noconsole.exe direct blocked .cert/proxy.crt .cert/proxy.key
+CGO_ENABLED=0 GOOS=windows GOARCH=386 go build $OPTS -ldflags="-H=windowsgui $X" -o proxy-noconsole.exe
+CGO_ENABLED=0 GOOS=windows GOARCH=386 go build $OPTS -ldflags "$X" -o proxy.exe && tar zcfv "${RELEASE}/proxy-windows-386.tar.gz" proxy.exe proxy-noconsole.exe direct blocked .cert/proxy.crt .cert/proxy.key
+CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build $OPTS -ldflags="-H=windowsgui $X" -o proxy-noconsole.exe
+CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build $OPTS -ldflags "$X" -o proxy.exe && tar zcfv "${RELEASE}/proxy-windows-amd64.tar.gz" proxy.exe proxy-noconsole.exe direct blocked .cert/proxy.crt .cert/proxy.key
rm -rf proxy proxy.exe proxy-noconsole.exe .cert
#todo
-#1.release.sh VER="xxx"
-#2.main.go APP_VERSION="xxx"
-#3.install_auto.sh goproxy/releases/download/vxxx
-#4.README goproxy/releases/download/vxxx
+#1.install_auto.sh goproxy/releases/download/vxxx
+#2.README goproxy/releases/download/vxxx
diff --git a/sdk/CHANGELOG b/sdk/CHANGELOG
index 9336b59..ec15bce 100644
--- a/sdk/CHANGELOG
+++ b/sdk/CHANGELOG
@@ -1,21 +1,3 @@
-SDK更新日志
-v5.4
-1.去掉了无用参数
-
-
-v5.3
-1.增加了支持日志输出回调的方法:
- StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
-2.优化了socks_client握手端口判断,避免了sstap测试UDP失败的问题..
-3.修复了HTTP(S)\SPS反向代理无法正常工作的问题.
-4.优化了智能判断,减少不必要的DNS解析.
-5.重构了SOCKS和SPS的UDP功能,基于UDP的游戏加速嗖嗖的.
-
-v4.9
-1.修复了HTTP Basic代理返回不合适的头部,导致浏览器不会弹框,个别代理插件无法认证的问题.
-2.内网穿透切换smux到yamux.
-3.优化了HTTP(S)\SOCKS5代理--always的处理逻辑.
-
v4.8
1.修复了多个服务同时开启日志,只会输出到最后一个日志文件的bug.
2.增加了获取sdk版本的Version()方法.
diff --git a/sdk/README.md b/sdk/README.md
index f21503e..f6a158b 100644
--- a/sdk/README.md
+++ b/sdk/README.md
@@ -25,7 +25,7 @@ proxy使用gombile实现了一份go代码编译为android和ios平台下面可
#### 1.导入包
```java
-import snail007.proxy.Proxy
+import snail007.proxy.Porxy
```
#### 2.启动一个服务
@@ -200,7 +200,7 @@ int main() {
```
#### 编译test-proxy.c ####
-`export LD_LIBRARY_PATH=./ && gcc -o test-proxy test.c libproxy-sdk.so`
+`export LD_LIBRARY_PATH=./ && gcc -o test-proxy test-proxy.c libproxy-sdk.so`
#### 执行 ####
`./test-proxy`
diff --git a/sdk/android-ios/dns.go b/sdk/android-ios/dns.go
index 375016d..9d74926 100644
--- a/sdk/android-ios/dns.go
+++ b/sdk/android-ios/dns.go
@@ -58,6 +58,11 @@ func (s *DNS) InitService() (err error) {
s.cache = gocache.New(time.Second*time.Duration(*s.cfg.DNSTTL), time.Second*60)
s.cache.LoadFile(*s.cfg.CacheFile)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
for {
select {
case <-s.exitSig:
@@ -76,7 +81,7 @@ func (s *DNS) InitService() (err error) {
nil,
&net.Dialer{
Timeout: 5 * time.Second,
- KeepAlive: 2 * time.Second,
+ KeepAlive: 5 * time.Second,
},
)
if err != nil {
@@ -117,7 +122,7 @@ func (s *DNS) StopService() {
if e != nil {
s.log.Printf("stop dns service crashed,%s", e)
} else {
- s.log.Printf("service dns stopped")
+ s.log.Printf("service dns stoped")
}
}()
Stop(s.serviceKey)
@@ -135,6 +140,11 @@ func (s *DNS) Start(args interface{}, log *logger.Logger) (err error) {
}
dns.HandleFunc(".", s.callback)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
log.Printf("dns server on udp %s", *s.cfg.Local)
err := dns.ListenAndServe(*s.cfg.Local, "udp", nil)
if err != nil {
diff --git a/sdk/android-ios/release_android.sh b/sdk/android-ios/release_android.sh
index f739580..92b8a9d 100755
--- a/sdk/android-ios/release_android.sh
+++ b/sdk/android-ios/release_android.sh
@@ -1,5 +1,8 @@
#/bin/bash
-VER="v5.4"
+VERSION=$(cat ../../VERSION)
+VER="${VERSION}_$(date '+%Y%m%d%H%M%S')"
+X="-X github.com/snail007/goproxy/sdk/android-ios.SDK_VERSION=$VER -X main.APP_VERSION=$VER"
+
rm -rf sdk-android-*.tar.gz
rm -rf android
mkdir android
@@ -14,11 +17,11 @@ mkdir android
#go get -v golang.org/x/mobile/cmd/gomobile
#gomobile init
-gomobile bind -v -target=android -javapkg=snail007 -ldflags="-s -w"
+gomobile bind -v -target=android -javapkg=snail007 -ldflags="-s -w $X"
mv proxy.aar android/snail007.goproxy.sdk.aar
mv proxy-sources.jar android/snail007.goproxy.sdk-sources.jar
cp ../README.md android
-tar zcfv sdk-android-${VER}.tar.gz android
+tar zcfv sdk-android-${VERSION}.tar.gz android
rm -rf android
echo "done."
diff --git a/sdk/android-ios/release_ios.sh b/sdk/android-ios/release_ios.sh
index 546a1c3..c0e4205 100755
--- a/sdk/android-ios/release_ios.sh
+++ b/sdk/android-ios/release_ios.sh
@@ -1,14 +1,17 @@
#/bin/bash
-VER="v5.4"
+VERSION=$(cat ../../VERSION)
+VER="${VERSION}_$(date '+%Y%m%d%H%M%S')"
+X="-X github.com/snail007/goproxy/sdk/android-ios.SDK_VERSION=$VER -X main.APP_VERSION=$VER"
+
rm -rf sdk-ios-*.tar.gz
rm -rf ios
mkdir ios
#ios XCode required
-gomobile bind -v -target=ios -ldflags="-s -w"
+gomobile bind -v -target=ios -ldflags="-s -w $X"
mv Proxy.framework ios
cp ../README.md ios
-tar zcfv sdk-ios-${VER}.tar.gz ios
+tar zcfv sdk-ios-${VERSION}.tar.gz ios
rm -rf ios
echo "done."
diff --git a/sdk/android-ios/sdk.go b/sdk/android-ios/sdk.go
index c00d042..7178e03 100644
--- a/sdk/android-ios/sdk.go
+++ b/sdk/android-ios/sdk.go
@@ -13,18 +13,20 @@ import (
"github.com/snail007/goproxy/services"
httpx "github.com/snail007/goproxy/services/http"
"github.com/snail007/goproxy/services/kcpcfg"
+ keygenx "github.com/snail007/goproxy/services/keygen"
mux "github.com/snail007/goproxy/services/mux"
socksx "github.com/snail007/goproxy/services/socks"
spsx "github.com/snail007/goproxy/services/sps"
tcpx "github.com/snail007/goproxy/services/tcp"
- tunnel "github.com/snail007/goproxy/services/tunnel"
+ tunnelx "github.com/snail007/goproxy/services/tunnel"
udpx "github.com/snail007/goproxy/services/udp"
+
kcp "github.com/xtaci/kcp-go"
"golang.org/x/crypto/pbkdf2"
kingpin "gopkg.in/alecthomas/kingpin.v2"
)
-const SDK_VERSION = "5.4"
+var SDK_VERSION = "No Version Provided"
var (
app *kingpin.Application
@@ -62,9 +64,9 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
//define args
tcpArgs := tcpx.TCPArgs{}
httpArgs := httpx.HTTPArgs{}
- tunnelServerArgs := tunnel.TunnelServerArgs{}
- tunnelClientArgs := tunnel.TunnelClientArgs{}
- tunnelBridgeArgs := tunnel.TunnelBridgeArgs{}
+ tunnelServerArgs := tunnelx.TunnelServerArgs{}
+ tunnelClientArgs := tunnelx.TunnelClientArgs{}
+ tunnelBridgeArgs := tunnelx.TunnelBridgeArgs{}
muxServerArgs := mux.MuxServerArgs{}
muxClientArgs := mux.MuxClientArgs{}
muxBridgeArgs := mux.MuxBridgeArgs{}
@@ -72,6 +74,7 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
socksArgs := socksx.SocksArgs{}
spsArgs := spsx.SPSArgs{}
dnsArgs := DNSArgs{}
+ keygenArgs := keygenx.KeygenArgs{}
kcpArgs := kcpcfg.KCPConfigArgs{}
//build srvice args
app = kingpin.New("proxy", "happy with proxy")
@@ -81,8 +84,8 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
nolog := app.Flag("nolog", "turn off logging").Default("false").Bool()
kcpArgs.Key = app.Flag("kcp-key", "pre-shared secret between client and server").Default("secrect").String()
kcpArgs.Crypt = app.Flag("kcp-method", "encrypt/decrypt method, can be: aes, aes-128, aes-192, salsa20, blowfish, twofish, cast5, 3des, tea, xtea, xor, sm4, none").Default("aes").Enum("aes", "aes-128", "aes-192", "salsa20", "blowfish", "twofish", "cast5", "3des", "tea", "xtea", "xor", "sm4", "none")
- kcpArgs.Mode = app.Flag("kcp-mode", "profiles: fast3, fast2, fast, normal, manual").Default("fast3").Enum("fast3", "fast2", "fast", "normal", "manual")
- kcpArgs.MTU = app.Flag("kcp-mtu", "set maximum transmission unit for UDP packets").Default("1350").Int()
+ kcpArgs.Mode = app.Flag("kcp-mode", "profiles: fast3, fast2, fast, normal, manual").Default("fast").Enum("fast3", "fast2", "fast", "normal", "manual")
+ kcpArgs.MTU = app.Flag("kcp-mtu", "set maximum transmission unit for UDP packets").Default("450").Int()
kcpArgs.SndWnd = app.Flag("kcp-sndwnd", "set send window size(num of packets)").Default("1024").Int()
kcpArgs.RcvWnd = app.Flag("kcp-rcvwnd", "set receive window size(num of packets)").Default("1024").Int()
kcpArgs.DataShard = app.Flag("kcp-ds", "set reed-solomon erasure coding - datashard").Default("10").Int()
@@ -99,7 +102,7 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
//########http#########
http := app.Command("http", "proxy on http mode")
- httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
+ httpArgs.Parent = http.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').Strings()
httpArgs.CaCertFile = http.Flag("ca", "ca cert file for tls").Default("").String()
httpArgs.CertFile = http.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
httpArgs.KeyFile = http.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
@@ -130,7 +133,14 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
httpArgs.ParentKey = http.Flag("parent-key", "the password for auto encrypt/decrypt parent connection data").Short('Z').Default("").String()
httpArgs.LocalCompress = http.Flag("local-compress", "auto compress/decompress data on local connection").Short('m').Default("false").Bool()
httpArgs.ParentCompress = http.Flag("parent-compress", "auto compress/decompress data on parent connection").Short('M').Default("false").Bool()
-
+ httpArgs.LoadBalanceMethod = http.Flag("lb-method", "load balance method when use multiple parent,can be ").Default("hash").Enum("roundrobin", "weight", "leastconn", "leasttime", "hash")
+ httpArgs.LoadBalanceTimeout = http.Flag("lb-timeout", "tcp milliseconds timeout of connecting to parent").Default("500").Int()
+ httpArgs.LoadBalanceRetryTime = http.Flag("lb-retrytime", "sleep time milliseconds after checking").Default("1000").Int()
+ httpArgs.LoadBalanceHashTarget = http.Flag("lb-hashtarget", "use target address to choose parent for LB").Default("false").Bool()
+ httpArgs.LoadBalanceOnlyHA = http.Flag("lb-onlyha", "use only `high availability mode` to choose parent for LB").Default("false").Bool()
+ httpArgs.RateLimit = http.Flag("rate-limit", "rate limit (bytes/second) of each connection, such as: 100K 1.5M . 0 means no limitation").Short('l').Default("0").String()
+ httpArgs.BindListen = http.Flag("bind-listen", "using listener binding IP when connect to target").Short('B').Default("false").Bool()
+ httpArgs.Debug = debug
//########tcp#########
tcp := app.Command("tcp", "proxy on tcp mode")
tcpArgs.Parent = tcp.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
@@ -139,6 +149,7 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
tcpArgs.Timeout = tcp.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Short('e').Default("2000").Int()
tcpArgs.ParentType = tcp.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "udp", "kcp")
tcpArgs.LocalType = tcp.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")
+ tcpArgs.CheckParentInterval = tcp.Flag("check-parent-interval", "check if proxy is okay every interval seconds,zero: means no check").Short('I').Default("3").Int()
tcpArgs.Local = tcp.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
tcpArgs.Jumper = tcp.Flag("jumper", "https or socks5 proxies used when connecting to parent, only worked of -T is tls or tcp, format is https://username:password@host:port https://host:port or socks5://username:password@host:port socks5://host:port").Short('J').Default("").String()
@@ -215,7 +226,7 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
//########ssh#########
socks := app.Command("socks", "proxy on ssh mode")
- socksArgs.Parent = socks.Flag("parent", "parent ssh address, such as: \"23.32.32.19:22\"").Default("").Short('P').String()
+ socksArgs.Parent = socks.Flag("parent", "parent ssh address, such as: \"23.32.32.19:22\"").Default("").Short('P').Strings()
socksArgs.ParentType = socks.Flag("parent-type", "parent protocol type ").Default("tcp").Short('T').Enum("tls", "tcp", "kcp", "ssh")
socksArgs.LocalType = socks.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")
socksArgs.Local = socks.Flag("local", "local ip:port to listen").Short('p').Default(":33080").String()
@@ -225,7 +236,7 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
socksArgs.SSHUser = socks.Flag("ssh-user", "user for ssh").Short('u').Default("").String()
socksArgs.SSHKeyFile = socks.Flag("ssh-key", "private key file for ssh").Short('S').Default("").String()
socksArgs.SSHKeyFileSalt = socks.Flag("ssh-keysalt", "salt of ssh private key").Short('s').Default("").String()
- socksArgs.SSHPassword = socks.Flag("ssh-password", "password for ssh").Short('A').Default("").String()
+ socksArgs.SSHPassword = socks.Flag("ssh-password", "password for ssh").Short('D').Default("").String()
socksArgs.Always = socks.Flag("always", "always use parent proxy").Default("false").Bool()
socksArgs.Timeout = socks.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Default("5000").Int()
socksArgs.Interval = socks.Flag("interval", "check domain if blocked every interval seconds").Default("10").Int()
@@ -238,16 +249,25 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
socksArgs.AuthURLTimeout = socks.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int()
socksArgs.AuthURLOkCode = socks.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int()
socksArgs.AuthURLRetry = socks.Flag("auth-retry", "access 'auth-url' fail and retry count").Default("0").Int()
+ socksArgs.ParentAuth = socks.Flag("parent-auth", "parent socks auth username and password, such as: -A user1:pass1").Short('A').String()
socksArgs.DNSAddress = socks.Flag("dns-address", "if set this, proxy will use this dns for resolve doamin").Short('q').Default("").String()
socksArgs.DNSTTL = socks.Flag("dns-ttl", "caching seconds of dns query result").Short('e').Default("300").Int()
socksArgs.LocalKey = socks.Flag("local-key", "the password for auto encrypt/decrypt local connection data").Short('z').Default("").String()
socksArgs.ParentKey = socks.Flag("parent-key", "the password for auto encrypt/decrypt parent connection data").Short('Z').Default("").String()
socksArgs.LocalCompress = socks.Flag("local-compress", "auto compress/decompress data on local connection").Short('m').Default("false").Bool()
socksArgs.ParentCompress = socks.Flag("parent-compress", "auto compress/decompress data on parent connection").Short('M').Default("false").Bool()
+ socksArgs.LoadBalanceMethod = socks.Flag("lb-method", "load balance method when use multiple parent,can be ").Default("hash").Enum("roundrobin", "weight", "leastconn", "leasttime", "hash")
+ socksArgs.LoadBalanceTimeout = socks.Flag("lb-timeout", "tcp milliseconds timeout of connecting to parent").Default("500").Int()
+ socksArgs.LoadBalanceRetryTime = socks.Flag("lb-retrytime", "sleep time milliseconds after checking").Default("1000").Int()
+ socksArgs.LoadBalanceHashTarget = socks.Flag("lb-hashtarget", "use target address to choose parent for LB").Default("false").Bool()
+ socksArgs.LoadBalanceOnlyHA = socks.Flag("lb-onlyha", "use only `high availability mode` to choose parent for LB").Default("false").Bool()
+ socksArgs.RateLimit = socks.Flag("rate-limit", "rate limit (bytes/second) of each connection, such as: 100K 1.5M . 0 means no limitation").Short('l').Default("0").String()
+ socksArgs.BindListen = socks.Flag("bind-listen", "using listener binding IP when connect to target").Short('B').Default("false").Bool()
+ socksArgs.Debug = debug
//########socks+http(s)#########
sps := app.Command("sps", "proxy on socks+http(s) mode")
- spsArgs.Parent = sps.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
+ spsArgs.Parent = sps.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').Strings()
spsArgs.CertFile = sps.Flag("cert", "cert file for tls").Short('C').Default("proxy.crt").String()
spsArgs.KeyFile = sps.Flag("key", "key file for tls").Short('K').Default("proxy.key").String()
spsArgs.CaCertFile = sps.Flag("ca", "ca cert file for tls").Default("").String()
@@ -255,12 +275,12 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
spsArgs.ParentType = sps.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "kcp")
spsArgs.LocalType = sps.Flag("local-type", "local protocol type ").Default("tcp").Short('t').Enum("tls", "tcp", "kcp")
spsArgs.Local = sps.Flag("local", "local ip:port to listen,multiple address use comma split,such as: 0.0.0.0:80,0.0.0.0:443").Short('p').Default(":33080").String()
- spsArgs.ParentServiceType = sps.Flag("parent-service-type", "parent service type ").Short('S').Enum("http", "socks")
+ spsArgs.ParentServiceType = sps.Flag("parent-service-type", "parent service type ").Short('S').Enum("http", "socks", "ss")
spsArgs.DNSAddress = sps.Flag("dns-address", "if set this, proxy will use this dns for resolve doamin").Short('q').Default("").String()
spsArgs.DNSTTL = sps.Flag("dns-ttl", "caching seconds of dns query result").Short('e').Default("300").Int()
spsArgs.AuthFile = sps.Flag("auth-file", "http basic auth file,\"username:password\" each line in file").Short('F').String()
spsArgs.Auth = sps.Flag("auth", "socks auth username and password, mutiple user repeat -a ,such as: -a user1:pass1 -a user2:pass2").Short('a').Strings()
- spsArgs.LocalIPS = sps.Flag("local bind ips", "if your host behind a nat,set your public ip here avoid dead loop").Short('g').Strings()
+ spsArgs.LocalIPS = sps.Flag("local-bind-ips", "if your host behind a nat,set your public ip here avoid dead loop").Short('g').Strings()
spsArgs.AuthURL = sps.Flag("auth-url", "auth username and password will send to this url,response http code equal to 'auth-code' means ok,others means fail.").Default("").String()
spsArgs.AuthURLTimeout = sps.Flag("auth-timeout", "access 'auth-url' timeout milliseconds").Default("3000").Int()
spsArgs.AuthURLOkCode = sps.Flag("auth-code", "access 'auth-url' success http code").Default("204").Int()
@@ -270,8 +290,21 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
spsArgs.ParentKey = sps.Flag("parent-key", "the password for auto encrypt/decrypt parent connection data").Short('Z').Default("").String()
spsArgs.LocalCompress = sps.Flag("local-compress", "auto compress/decompress data on local connection").Short('m').Default("false").Bool()
spsArgs.ParentCompress = sps.Flag("parent-compress", "auto compress/decompress data on parent connection").Short('M').Default("false").Bool()
+ spsArgs.SSMethod = sps.Flag("ss-method", "the following methods are supported: aes-128-cfb, aes-192-cfb, aes-256-cfb, bf-cfb, cast5-cfb, des-cfb, rc4-md5, rc4-md5-6, chacha20, salsa20, rc4, table, des-cfb, chacha20-ietf; if you use ss client , \"-t tcp\" is required").Short('h').Default("aes-256-cfb").String()
+ spsArgs.SSKey = sps.Flag("ss-key", "if you use ss client , \"-t tcp\" is required").Short('j').Default("sspassword").String()
+ spsArgs.ParentSSMethod = sps.Flag("parent-ss-method", "the following methods are supported: aes-128-cfb, aes-192-cfb, aes-256-cfb, bf-cfb, cast5-cfb, des-cfb, rc4-md5, rc4-md5-6, chacha20, salsa20, rc4, table, des-cfb, chacha20-ietf; if you use ss server as parent, \"-T tcp\" is required").Short('H').Default("aes-256-cfb").String()
+ spsArgs.ParentSSKey = sps.Flag("parent-ss-key", "if you use ss server as parent, \"-T tcp\" is required").Short('J').Default("sspassword").String()
spsArgs.DisableHTTP = sps.Flag("disable-http", "disable http(s) proxy").Default("false").Bool()
spsArgs.DisableSocks5 = sps.Flag("disable-socks", "disable socks proxy").Default("false").Bool()
+ spsArgs.DisableSS = sps.Flag("disable-ss", "disable ss proxy").Default("false").Bool()
+ spsArgs.LoadBalanceMethod = sps.Flag("lb-method", "load balance method when use multiple parent,can be ").Default("hash").Enum("roundrobin", "weight", "leastconn", "leasttime", "hash")
+ spsArgs.LoadBalanceTimeout = sps.Flag("lb-timeout", "tcp milliseconds timeout of connecting to parent").Default("500").Int()
+ spsArgs.LoadBalanceRetryTime = sps.Flag("lb-retrytime", "sleep time milliseconds after checking").Default("1000").Int()
+ spsArgs.LoadBalanceHashTarget = sps.Flag("lb-hashtarget", "use target address to choose parent for LB").Default("false").Bool()
+ spsArgs.LoadBalanceOnlyHA = sps.Flag("lb-onlyha", "use only `high availability mode` to choose parent for LB").Default("false").Bool()
+ spsArgs.RateLimit = sps.Flag("rate-limit", "rate limit (bytes/second) of each connection, such as: 100K 1.5M . 0 means no limitation").Short('l').Default("0").String()
+ spsArgs.Debug = debug
+
//########dns#########
dns := app.Command("dns", "proxy on dns server mode")
dnsArgs.Parent = dns.Flag("parent", "parent address, such as: \"23.32.32.19:28008\"").Default("").Short('P').String()
@@ -280,7 +313,7 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
dnsArgs.CaCertFile = dns.Flag("ca", "ca cert file for tls").Default("").String()
dnsArgs.Timeout = dns.Flag("timeout", "tcp timeout milliseconds when connect to real server or parent proxy").Short('i').Default("2000").Int()
dnsArgs.ParentType = dns.Flag("parent-type", "parent protocol type ").Short('T').Enum("tls", "tcp", "kcp")
- dnsArgs.Local = dns.Flag("local", "local ip:port to listen,multiple address use comma split,such as: 0.0.0.0:80,0.0.0.0:443").Short('p').Default(":33080").String()
+ dnsArgs.Local = dns.Flag("local", "local ip:port to listen,multiple address use comma split,such as: 0.0.0.0:80,0.0.0.0:443").Short('p').Default(":53").String()
dnsArgs.ParentServiceType = dns.Flag("parent-service-type", "parent service type ").Short('S').Enum("http", "socks")
dnsArgs.RemoteDNSAddress = dns.Flag("dns-address", "remote dns for resolve doamin").Short('q').Default("8.8.8.8:53").String()
dnsArgs.DNSTTL = dns.Flag("dns-ttl", "caching seconds of dns query result").Short('e').Default("300").Int()
@@ -290,6 +323,14 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
dnsArgs.CacheFile = dns.Flag("cache-file", "dns result cached file").Short('f').Default(filepath.Join(path.Dir(os.Args[0]), "cache.dat")).String()
dnsArgs.LocalSocks5Port = dns.Flag("socks-port", "local socks5 port").Short('s').Default("65501").String()
+ //########keygen#########
+ keygen := app.Command("keygen", "create certificate for proxy")
+ keygenArgs.CommonName = keygen.Flag("cn", "common name").Short('n').Default("").String()
+ keygenArgs.CaName = keygen.Flag("ca", "ca name").Short('C').Default("").String()
+ keygenArgs.CertName = keygen.Flag("cert", "cert name of sign to create").Short('c').Default("").String()
+ keygenArgs.SignDays = keygen.Flag("days", "days of sign").Short('d').Default("365").Int()
+ keygenArgs.Sign = keygen.Flag("sign", "cert is to signin").Short('s').Default("false").Bool()
+
//parse args
_args := strings.Fields(strings.Trim(serviceArgsStr, " "))
args := []string{}
@@ -387,11 +428,11 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
case "udp":
services.Regist(serviceID, udpx.NewUDP(), udpArgs, log)
case "tserver":
- services.Regist(serviceID, tunnel.NewTunnelServerManager(), tunnelServerArgs, log)
+ services.Regist(serviceID, tunnelx.NewTunnelServerManager(), tunnelServerArgs, log)
case "tclient":
- services.Regist(serviceID, tunnel.NewTunnelClient(), tunnelClientArgs, log)
+ services.Regist(serviceID, tunnelx.NewTunnelClient(), tunnelClientArgs, log)
case "tbridge":
- services.Regist(serviceID, tunnel.NewTunnelBridge(), tunnelBridgeArgs, log)
+ services.Regist(serviceID, tunnelx.NewTunnelBridge(), tunnelBridgeArgs, log)
case "server":
services.Regist(serviceID, mux.NewMuxServerManager(), muxServerArgs, log)
case "client":
@@ -403,7 +444,7 @@ func StartWithLog(serviceID, serviceArgsStr string, loggerCallback LogCallback)
case "sps":
services.Regist(serviceID, spsx.NewSPS(), spsArgs, log)
case "dns":
- services.Regist(serviceName, NewDNS(), dnsArgs, log)
+ services.Regist(serviceID, NewDNS(), dnsArgs, log)
}
_, err = services.Run(serviceID, nil)
if err != nil {
diff --git a/sdk/windows-linux/release_linux.sh b/sdk/windows-linux/release_linux.sh
index 232916e..cb91e75 100755
--- a/sdk/windows-linux/release_linux.sh
+++ b/sdk/windows-linux/release_linux.sh
@@ -1,21 +1,29 @@
#/bin/bash
-VER="v5.4"
+VERSION=$(cat ../../VERSION)
+VER="${VERSION}_$(date '+%Y%m%d%H%M%S')"
+X="-X github.com/snail007/goproxy/sdk/android-ios.SDK_VERSION=$VER -X main.APP_VERSION=$VER"
+TRIMPATH1="/Users/snail/go/src/github.com/snail007"
+TRIMPATH=$(dirname ~/go/src/github.com/snail007)/snail007
+if [ -d "$TRIMPATH1" ];then
+ TRIMPATH=$TRIMPATH1
+fi
+OPTS="-gcflags=-trimpath=$TRIMPATH -asmflags=-trimpath=$TRIMPATH"
rm -rf sdk-linux-*.tar.gz
rm -rf README.md libproxy-sdk.so libproxy-sdk.h libproxy-sdk.a
#linux 32bit
-CGO_ENABLED=1 GOARCH=386 GOOS=linux go build -buildmode=c-archive -ldflags "-s -w" -o libproxy-sdk.a sdk.go
-CGO_ENABLED=1 GOARCH=386 GOOS=linux go build -buildmode=c-shared -ldflags "-s -w" -o libproxy-sdk.so sdk.go
+CGO_ENABLED=1 GOARCH=386 GOOS=linux go build -buildmode=c-archive $OPTS -ldflags "-s -w $X" -o libproxy-sdk.a sdk.go
+CGO_ENABLED=1 GOARCH=386 GOOS=linux go build -buildmode=c-shared $OPTS -ldflags "-s -w $X" -o libproxy-sdk.so sdk.go
cp ../README.md .
-tar zcf sdk-linux-32bit-${VER}.tar.gz README.md libproxy-sdk.so libproxy-sdk.a libproxy-sdk.h
+tar zcf sdk-linux-32bit-${VERSION}.tar.gz README.md libproxy-sdk.so libproxy-sdk.a libproxy-sdk.h
rm -rf README.md libproxy-sdk.so libproxy-sdk.h libproxy-sdk.a
#linux 64bit
-CGO_ENABLED=1 GOARCH=amd64 GOOS=linux go build -buildmode=c-archive -ldflags "-s -w" -o libproxy-sdk.a sdk.go
-CGO_ENABLED=1 GOARCH=amd64 GOOS=linux go build -buildmode=c-shared -ldflags "-s -w" -o libproxy-sdk.so sdk.go
+CGO_ENABLED=1 GOARCH=amd64 GOOS=linux go build -buildmode=c-archive $OPTS -ldflags "-s -w $X" -o libproxy-sdk.a sdk.go
+CGO_ENABLED=1 GOARCH=amd64 GOOS=linux go build -buildmode=c-shared $OPTS -ldflags "-s -w $X" -o libproxy-sdk.so sdk.go
cp ../README.md .
-tar zcf sdk-linux-64bit-${VER}.tar.gz README.md libproxy-sdk.so libproxy-sdk.a libproxy-sdk.h
+tar zcf sdk-linux-64bit-${VERSION}.tar.gz README.md libproxy-sdk.so libproxy-sdk.a libproxy-sdk.h
rm -rf README.md libproxy-sdk.so libproxy-sdk.h libproxy-sdk.a
echo "done."
diff --git a/sdk/windows-linux/release_mac.sh b/sdk/windows-linux/release_mac.sh
index cad6c98..18bfaf9 100755
--- a/sdk/windows-linux/release_mac.sh
+++ b/sdk/windows-linux/release_mac.sh
@@ -1,13 +1,21 @@
#/bin/bash
-VER="v5.4"
+VERSION=$(cat ../../VERSION)
+VER="${VERSION}_$(date '+%Y%m%d%H%M%S')"
+X="-X github.com/snail007/goproxy/sdk/android-ios.SDK_VERSION=$VER -X main.APP_VERSION=$VER"
+TRIMPATH1="/Users/snail/go/src/github.com/snail007"
+TRIMPATH=$(dirname ~/go/src/github.com/snail007)/snail007
+if [ -d "$TRIMPATH1" ];then
+ TRIMPATH=$TRIMPATH1
+fi
+OPTS="-gcflags=-trimpath=$TRIMPATH -asmflags=-trimpath=$TRIMPATH"
rm -rf *.tar.gz
rm -rf README.md libproxy-sdk.dylib libproxy-sdk.h
#mac , macos required
-CGO_ENABLED=1 GOARCH=amd64 GOOS=darwin go build -buildmode=c-shared -ldflags "-s -w" -o libproxy-sdk.dylib sdk.go
+CGO_ENABLED=1 GOARCH=amd64 GOOS=darwin go build -buildmode=c-shared $OPTS -ldflags "-s -w $X" -o libproxy-sdk.dylib sdk.go
cp ../README.md .
-tar zcf sdk-mac-${VER}.tar.gz README.md libproxy-sdk.dylib libproxy-sdk.h
+tar zcf sdk-mac-${VERSION}.tar.gz README.md libproxy-sdk.dylib libproxy-sdk.h
rm -rf README.md libproxy-sdk.dylib libproxy-sdk.h
echo "done."
diff --git a/sdk/windows-linux/release_windows.sh b/sdk/windows-linux/release_windows.sh
index 5a736ac..697752c 100755
--- a/sdk/windows-linux/release_windows.sh
+++ b/sdk/windows-linux/release_windows.sh
@@ -1,5 +1,13 @@
#/bin/bash
-VER="v5.4"
+VERSION=$(cat ../../VERSION)
+VER="${VERSION}_$(date '+%Y%m%d%H%M%S')"
+X="-X github.com/snail007/goproxy/sdk/android-ios.SDK_VERSION=$VER -X main.APP_VERSION=$VER"
+TRIMPATH1="/Users/snail/go/src/github.com/snail007"
+TRIMPATH=$(dirname ~/go/src/github.com/snail007)/snail007
+if [ -d "$TRIMPATH1" ];then
+ TRIMPATH=$TRIMPATH1
+fi
+OPTS="-gcflags=-trimpath=$TRIMPATH -asmflags=-trimpath=$TRIMPATH"
#sudo rm /usr/local/go
#sudo ln -s /usr/local/go1.10.1 /usr/local/go
@@ -11,15 +19,15 @@ rm -rf README.md proxy-sdk.h proxy-sdk.dll
#apt-get install gcc-mingw-w64
#windows 64bit
-CC=x86_64-w64-mingw32-gcc GOARCH=amd64 CGO_ENABLED=1 GOOS=windows go build -buildmode=c-shared -ldflags "-s -w" -o proxy-sdk.dll sdk.go
+CC=x86_64-w64-mingw32-gcc GOARCH=amd64 CGO_ENABLED=1 GOOS=windows go build $OPTS -buildmode=c-shared -ldflags "-s -w $X" -o proxy-sdk.dll sdk.go
cp ../README.md .
-tar zcf sdk-windows-64bit-${VER}.tar.gz README.md proxy-sdk.dll proxy-sdk.h ieshims.dll
+tar zcf sdk-windows-64bit-${VERSION}.tar.gz README.md proxy-sdk.dll proxy-sdk.h ieshims.dll
rm -rf README.md proxy-sdk.h proxy-sdk.dll
#windows 32bit
-CC=i686-w64-mingw32-gcc-win32 GOARCH=386 CGO_ENABLED=1 GOOS=windows go build -buildmode=c-shared -ldflags "-s -w" -o proxy-sdk.dll sdk.go
+CC=i686-w64-mingw32-gcc-win32 GOARCH=386 CGO_ENABLED=1 GOOS=windows go build $OPTS -buildmode=c-shared -ldflags "-s -w $X" -o proxy-sdk.dll sdk.go
cp ../README.md .
-tar zcf sdk-windows-32bit-${VER}.tar.gz README.md proxy-sdk.dll proxy-sdk.h ieshims.dll
+tar zcf sdk-windows-32bit-${VERSION}.tar.gz README.md proxy-sdk.dll proxy-sdk.h ieshims.dll
rm -rf README.md proxy-sdk.h proxy-sdk.dll
#sudo rm /usr/local/go
diff --git a/services/http/http.go b/services/http/http.go
index 0d33093..5685342 100644
--- a/services/http/http.go
+++ b/services/http/http.go
@@ -1,6 +1,7 @@
package http
import (
+ "crypto/tls"
"fmt"
"io"
"io/ioutil"
@@ -13,6 +14,12 @@ import (
"github.com/snail007/goproxy/services"
"github.com/snail007/goproxy/services/kcpcfg"
+ "github.com/snail007/goproxy/utils/datasize"
+ "github.com/snail007/goproxy/utils/dnsx"
+ "github.com/snail007/goproxy/utils/iolimiter"
+ "github.com/snail007/goproxy/utils/lb"
+ "github.com/snail007/goproxy/utils/mapx"
+
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/conncrypt"
@@ -20,72 +27,85 @@ import (
)
type HTTPArgs struct {
- Parent *string
- CertFile *string
- KeyFile *string
- CaCertFile *string
- CaCertBytes []byte
- CertBytes []byte
- KeyBytes []byte
- Local *string
- Always *bool
- HTTPTimeout *int
- Interval *int
- Blocked *string
- Direct *string
- AuthFile *string
- Auth *[]string
- AuthURL *string
- AuthURLOkCode *int
- AuthURLTimeout *int
- AuthURLRetry *int
- ParentType *string
- LocalType *string
- Timeout *int
- CheckParentInterval *int
- SSHKeyFile *string
- SSHKeyFileSalt *string
- SSHPassword *string
- SSHUser *string
- SSHKeyBytes []byte
- SSHAuthMethod ssh.AuthMethod
- KCP kcpcfg.KCPConfigArgs
- LocalIPS *[]string
- DNSAddress *string
- DNSTTL *int
- LocalKey *string
- ParentKey *string
- LocalCompress *bool
- ParentCompress *bool
+ Parent *[]string
+ CertFile *string
+ KeyFile *string
+ CaCertFile *string
+ CaCertBytes []byte
+ CertBytes []byte
+ KeyBytes []byte
+ Local *string
+ Always *bool
+ HTTPTimeout *int
+ Interval *int
+ Blocked *string
+ Direct *string
+ AuthFile *string
+ Auth *[]string
+ AuthURL *string
+ AuthURLOkCode *int
+ AuthURLTimeout *int
+ AuthURLRetry *int
+ ParentType *string
+ LocalType *string
+ Timeout *int
+ CheckParentInterval *int
+ SSHKeyFile *string
+ SSHKeyFileSalt *string
+ SSHPassword *string
+ SSHUser *string
+ SSHKeyBytes []byte
+ SSHAuthMethod ssh.AuthMethod
+ KCP kcpcfg.KCPConfigArgs
+ LocalIPS *[]string
+ DNSAddress *string
+ DNSTTL *int
+ LocalKey *string
+ ParentKey *string
+ LocalCompress *bool
+ ParentCompress *bool
+ LoadBalanceMethod *string
+ LoadBalanceTimeout *int
+ LoadBalanceRetryTime *int
+ LoadBalanceHashTarget *bool
+ LoadBalanceOnlyHA *bool
+
+ RateLimit *string
+ RateLimitBytes float64
+ BindListen *bool
+ Debug *bool
}
type HTTP struct {
- outPool utils.OutConn
cfg HTTPArgs
checker utils.Checker
basicAuth utils.BasicAuth
sshClient *ssh.Client
lockChn chan bool
- domainResolver utils.DomainResolver
+ domainResolver dnsx.DomainResolver
isStop bool
serverChannels []*utils.ServerChannel
- userConns utils.ConcurrentMap
+ userConns mapx.ConcurrentMap
log *logger.Logger
+ lb *lb.Group
}
func NewHTTP() services.Service {
return &HTTP{
- outPool: utils.OutConn{},
cfg: HTTPArgs{},
checker: utils.Checker{},
basicAuth: utils.BasicAuth{},
lockChn: make(chan bool, 1),
isStop: false,
serverChannels: []*utils.ServerChannel{},
- userConns: utils.NewConcurrentMap(),
+ userConns: mapx.NewConcurrentMap(),
}
}
func (s *HTTP) CheckArgs() (err error) {
- if *s.cfg.Parent != "" && *s.cfg.ParentType == "" {
+
+ if len(*s.cfg.Parent) == 1 && (*s.cfg.Parent)[0] == "" {
+ (*s.cfg.Parent) = []string{}
+ }
+ if len(*s.cfg.Parent) > 0 && *s.cfg.ParentType == "" {
err = fmt.Errorf("parent type unkown,use -T ")
return
}
@@ -133,15 +153,26 @@ func (s *HTTP) CheckArgs() (err error) {
s.cfg.SSHAuthMethod = ssh.PublicKeys(SSHSigner)
}
}
+ if *s.cfg.RateLimit != "0" && *s.cfg.RateLimit != "" {
+ var size uint64
+ size, err = datasize.Parse(*s.cfg.RateLimit)
+ if err != nil {
+ err = fmt.Errorf("parse rate limit size error,ERR:%s", err)
+ return
+ }
+ s.cfg.RateLimitBytes = float64(size)
+ }
return
}
func (s *HTTP) InitService() (err error) {
s.InitBasicAuth()
- if *s.cfg.Parent != "" {
+ //init lb
+ if len(*s.cfg.Parent) > 0 {
s.checker = utils.NewChecker(*s.cfg.HTTPTimeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct, s.log)
+ s.InitLB()
}
if *s.cfg.DNSAddress != "" {
- (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log)
+ (*s).domainResolver = dnsx.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log)
}
if *s.cfg.ParentType == "ssh" {
err = s.ConnectSSH()
@@ -150,12 +181,17 @@ func (s *HTTP) InitService() (err error) {
return
}
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
//循环检查ssh网络连通性
for {
if s.isStop {
return
}
- conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
+ conn, err := utils.ConnectHost(s.Resolve(s.lb.Select("", *s.cfg.LoadBalanceOnlyHA)), *s.cfg.Timeout*2)
if err == nil {
conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
_, err = conn.Write([]byte{0})
@@ -185,11 +221,22 @@ func (s *HTTP) StopService() {
if e != nil {
s.log.Printf("stop http(s) service crashed,%s", e)
} else {
- s.log.Printf("service http(s) stopped")
+ s.log.Printf("service http(s) stoped")
}
+ s.basicAuth = utils.BasicAuth{}
+ s.cfg = HTTPArgs{}
+ s.checker = utils.Checker{}
+ s.domainResolver = dnsx.DomainResolver{}
+ s.lb = nil
+ s.lockChn = nil
+ s.log = nil
+ s.serverChannels = nil
+ s.sshClient = nil
+ s.userConns = nil
+ s = nil
}()
s.isStop = true
- if *s.cfg.Parent != "" {
+ if len(*s.cfg.Parent) > 0 {
s.checker.Stop()
}
if s.sshClient != nil {
@@ -203,6 +250,9 @@ func (s *HTTP) StopService() {
(*sc.UDPListener).Close()
}
}
+ if s.lb != nil {
+ s.lb.Stop()
+ }
}
func (s *HTTP) Start(args interface{}, log *logger.Logger) (err error) {
s.log = log
@@ -215,9 +265,8 @@ func (s *HTTP) Start(args interface{}, log *logger.Logger) (err error) {
return
}
- if *s.cfg.Parent != "" {
- s.log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent)
- s.InitOutConnPool()
+ if len(*s.cfg.Parent) > 0 {
+ s.log.Printf("use %s parent %v [ %s ]", *s.cfg.ParentType, *s.cfg.Parent, strings.ToUpper(*s.cfg.LoadBalanceMethod))
}
for _, addr := range strings.Split(*s.cfg.Local, ",") {
@@ -272,9 +321,9 @@ func (s *HTTP) callback(inConn net.Conn) {
address := req.Host
host, _, _ := net.SplitHostPort(address)
useProxy := false
- if !utils.IsIternalIP(host, *s.cfg.Always) {
+ if !utils.IsInternalIP(host, *s.cfg.Always) {
useProxy = true
- if *s.cfg.Parent == "" {
+ if len(*s.cfg.Parent) == 0 {
useProxy = false
} else if *s.cfg.Always {
useProxy = true
@@ -290,17 +339,17 @@ func (s *HTTP) callback(inConn net.Conn) {
s.log.Printf("use proxy : %v, %s", useProxy, address)
- err = s.OutToTCP(useProxy, address, &inConn, &req)
+ lbAddr, err := s.OutToTCP(useProxy, address, &inConn, &req)
if err != nil {
- if *s.cfg.Parent == "" {
+ if len(*s.cfg.Parent) == 0 {
s.log.Printf("connect to %s fail, ERR:%s", address, err)
} else {
- s.log.Printf("connect to %s parent %s fail", *s.cfg.ParentType, *s.cfg.Parent)
+ s.log.Printf("connect to %s parent %v fail", *s.cfg.ParentType, lbAddr)
}
utils.CloseConn(&inConn)
}
}
-func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *utils.HTTPRequest) (err interface{}) {
+func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *utils.HTTPRequest) (lbAddr string, err interface{}) {
inAddr := (*inConn).RemoteAddr().String()
inLocalAddr := (*inConn).LocalAddr().String()
//防止死循环
@@ -317,25 +366,26 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
return
}
if useProxy {
- if *s.cfg.ParentType == "ssh" {
- outConn, err = s.getSSHConn(address)
- } else {
- // s.log.Printf("%v", s.outPool)
- outConn, err = s.outPool.Get()
+ // s.log.Printf("%v", s.outPool)
+ selectAddr := (*inConn).RemoteAddr().String()
+ if utils.LBMethod(*s.cfg.LoadBalanceMethod) == lb.SELECT_HASH && *s.cfg.LoadBalanceHashTarget {
+ selectAddr = address
}
+ lbAddr = s.lb.Select(selectAddr, *s.cfg.LoadBalanceOnlyHA)
+ outConn, err = s.GetParentConn(lbAddr)
} else {
- outConn, err = utils.ConnectHost(s.Resolve(address), *s.cfg.Timeout)
+ outConn, err = s.GetDirectConn(s.Resolve(address), inLocalAddr)
}
tryCount++
if err == nil || tryCount > maxTryCount {
break
} else {
- s.log.Printf("connect to %s , err:%s,retrying...", *s.cfg.Parent, err)
+ s.log.Printf("connect to %s , err:%s,retrying...", lbAddr, err)
time.Sleep(time.Second * 2)
}
}
if err != nil {
- s.log.Printf("connect to %s , err:%s", *s.cfg.Parent, err)
+ s.log.Printf("connect to %s , err:%s", lbAddr, err)
utils.CloseConn(inConn)
return
}
@@ -347,6 +397,7 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
Password: *s.cfg.ParentKey,
})
}
+
outAddr := outConn.RemoteAddr().String()
//outLocalAddr := outConn.LocalAddr().String()
if req.IsHTTPS() && (!useProxy || *s.cfg.ParentType == "ssh") {
@@ -355,29 +406,39 @@ func (s *HTTP) OutToTCP(useProxy bool, address string, inConn *net.Conn, req *ut
} else {
//https或者http,上级是代理,proxy需要转发
outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- //直连目标或上级非代理或非SNI,清理HTTP头部的代理头信息.
- if (!useProxy || *s.cfg.ParentType == "ssh") && !req.IsSNI {
+ //直连目标或上级非代理或非SNI,,清理HTTP头部的代理头信息
+ if !useProxy || *s.cfg.ParentType == "ssh" && !req.IsSNI {
_, err = outConn.Write(utils.RemoveProxyHeaders(req.HeadBuf))
} else {
_, err = outConn.Write(req.HeadBuf)
}
outConn.SetDeadline(time.Time{})
if err != nil {
- s.log.Printf("write to %s , err:%s", *s.cfg.Parent, err)
+ s.log.Printf("write to %s , err:%s", lbAddr, err)
utils.CloseConn(inConn)
return
}
}
+ if s.cfg.RateLimitBytes > 0 {
+ outConn = iolimiter.NewReaderConn(outConn, s.cfg.RateLimitBytes)
+ }
+
utils.IoBind((*inConn), outConn, func(err interface{}) {
s.log.Printf("conn %s - %s released [%s]", inAddr, outAddr, req.Host)
s.userConns.Remove(inAddr)
+ if len(*s.cfg.Parent) > 0 {
+ s.lb.DecreaseConns(lbAddr)
+ }
}, s.log)
s.log.Printf("conn %s - %s connected [%s]", inAddr, outAddr, req.Host)
if c, ok := s.userConns.Get(inAddr); ok {
(*c.(*net.Conn)).Close()
}
s.userConns.Set(inAddr, inConn)
+ if len(*s.cfg.Parent) > 0 {
+ s.lb.IncreasConns(lbAddr)
+ }
return
}
@@ -434,24 +495,11 @@ func (s *HTTP) ConnectSSH() (err error) {
if s.sshClient != nil {
s.sshClient.Close()
}
- s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
+ s.sshClient, err = ssh.Dial("tcp", s.Resolve(s.lb.Select("", *s.cfg.LoadBalanceOnlyHA)), &config)
<-s.lockChn
return
}
-func (s *HTTP) InitOutConnPool() {
- if *s.cfg.ParentType == "tls" || *s.cfg.ParentType == "tcp" || *s.cfg.ParentType == "kcp" {
- //dur int, isTLS bool, certBytes, keyBytes []byte,
- //parent string, timeout int, InitialCap int, MaxCap int
- s.outPool = utils.NewOutConn(
- *s.cfg.CheckParentInterval,
- *s.cfg.ParentType,
- s.cfg.KCP,
- s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes,
- s.Resolve(*s.cfg.Parent),
- *s.cfg.Timeout,
- )
- }
-}
+
func (s *HTTP) InitBasicAuth() (err error) {
if *s.cfg.DNSAddress != "" {
s.basicAuth = utils.NewBasicAuth(&(*s).domainResolver, s.log)
@@ -477,6 +525,27 @@ func (s *HTTP) InitBasicAuth() (err error) {
}
return
}
+func (s *HTTP) InitLB() {
+ configs := lb.BackendsConfig{}
+ for _, addr := range *s.cfg.Parent {
+ _addrInfo := strings.Split(addr, "@")
+ _addr := _addrInfo[0]
+ weight := 1
+ if len(_addrInfo) == 2 {
+ weight, _ = strconv.Atoi(_addrInfo[1])
+ }
+ configs = append(configs, &lb.BackendConfig{
+ Address: _addr,
+ Weight: weight,
+ ActiveAfter: 1,
+ InactiveAfter: 2,
+ Timeout: time.Duration(*s.cfg.LoadBalanceTimeout) * time.Millisecond,
+ RetryTime: time.Duration(*s.cfg.LoadBalanceRetryTime) * time.Millisecond,
+ })
+ }
+ LB := lb.NewGroup(utils.LBMethod(*s.cfg.LoadBalanceMethod), configs, &s.domainResolver, s.log, *s.cfg.Debug)
+ s.lb = &LB
+}
func (s *HTTP) IsBasicAuth() bool {
return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 || *s.cfg.AuthURL != ""
}
@@ -494,7 +563,7 @@ func (s *HTTP) IsDeadLoop(inLocalAddr string, host string) bool {
if *s.cfg.DNSAddress != "" {
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
} else {
- outIPs, err = utils.MyLookupIP(outDomain)
+ outIPs, err = utils.LookupIP(outDomain)
}
if err == nil {
for _, ip := range outIPs {
@@ -530,3 +599,39 @@ func (s *HTTP) Resolve(address string) string {
}
return ip
}
+func (s *HTTP) GetParentConn(address string) (conn net.Conn, err error) {
+ if *s.cfg.ParentType == "tls" {
+ var _conn tls.Conn
+ _conn, err = utils.TlsConnectHost(address, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes)
+ if err == nil {
+ conn = net.Conn(&_conn)
+ }
+ } else if *s.cfg.ParentType == "kcp" {
+ conn, err = utils.ConnectKCPHost(address, s.cfg.KCP)
+ } else if *s.cfg.ParentType == "ssh" {
+ var e interface{}
+ conn, e = s.getSSHConn(address)
+ if e != nil {
+ err = fmt.Errorf("%s", e)
+ }
+ } else {
+ conn, err = utils.ConnectHost(address, *s.cfg.Timeout)
+ }
+ return
+}
+func (s *HTTP) GetDirectConn(address string, localAddr string) (conn net.Conn, err error) {
+ if !*s.cfg.BindListen {
+ return utils.ConnectHost(address, *s.cfg.Timeout)
+ }
+ ip, _, _ := net.SplitHostPort(localAddr)
+ if utils.IsInternalIP(ip, false) {
+ return utils.ConnectHost(address, *s.cfg.Timeout)
+ }
+ local, _ := net.ResolveTCPAddr("tcp", ip+":0")
+ d := net.Dialer{
+ Timeout: time.Millisecond * time.Duration(*s.cfg.Timeout),
+ LocalAddr: local,
+ }
+ conn, err = d.Dial("tcp", address)
+ return
+}
diff --git a/services/mux/mux_bridge.go b/services/mux/mux_bridge.go
index 5268151..c589e66 100644
--- a/services/mux/mux_bridge.go
+++ b/services/mux/mux_bridge.go
@@ -7,6 +7,7 @@ import (
logger "log"
"math/rand"
"net"
+ "runtime/debug"
"strconv"
"strings"
"sync"
@@ -15,15 +16,12 @@ import (
"github.com/snail007/goproxy/services"
"github.com/snail007/goproxy/services/kcpcfg"
"github.com/snail007/goproxy/utils"
+ "github.com/snail007/goproxy/utils/mapx"
+
//"github.com/xtaci/smux"
smux "github.com/hashicorp/yamux"
)
-const (
- CONN_SERVER = uint8(4)
- CONN_CLIENT = uint8(5)
-)
-
type MuxBridgeArgs struct {
CertFile *string
KeyFile *string
@@ -37,8 +35,8 @@ type MuxBridgeArgs struct {
}
type MuxBridge struct {
cfg MuxBridgeArgs
- clientControlConns utils.ConcurrentMap
- serverConns utils.ConcurrentMap
+ clientControlConns mapx.ConcurrentMap
+ serverConns mapx.ConcurrentMap
router utils.ClientKeyRouter
l *sync.Mutex
isStop bool
@@ -49,8 +47,8 @@ type MuxBridge struct {
func NewMuxBridge() services.Service {
b := &MuxBridge{
cfg: MuxBridgeArgs{},
- clientControlConns: utils.NewConcurrentMap(),
- serverConns: utils.NewConcurrentMap(),
+ clientControlConns: mapx.NewConcurrentMap(),
+ serverConns: mapx.NewConcurrentMap(),
l: &sync.Mutex{},
isStop: false,
}
@@ -80,15 +78,23 @@ func (s *MuxBridge) StopService() {
if e != nil {
s.log.Printf("stop bridge service crashed,%s", e)
} else {
- s.log.Printf("service bridge stopped")
+ s.log.Printf("service bridge stoped")
}
+ s.cfg = MuxBridgeArgs{}
+ s.clientControlConns = nil
+ s.l = nil
+ s.log = nil
+ s.router = utils.ClientKeyRouter{}
+ s.sc = nil
+ s.serverConns = nil
+ s = nil
}()
s.isStop = true
if s.sc != nil && (*s.sc).Listener != nil {
(*(*s.sc).Listener).Close()
}
for _, g := range s.clientControlConns.Items() {
- for _, session := range g.(*utils.ConcurrentMap).Items() {
+ for _, session := range g.(*mapx.ConcurrentMap).Items() {
(session.(*smux.Session)).Close()
}
}
@@ -201,17 +207,22 @@ func (s *MuxBridge) handler(inConn net.Conn) {
s.l.Lock()
defer s.l.Unlock()
if !s.clientControlConns.Has(groupKey) {
- item := utils.NewConcurrentMap()
+ item := mapx.NewConcurrentMap()
s.clientControlConns.Set(groupKey, &item)
}
_group, _ := s.clientControlConns.Get(groupKey)
- group := _group.(*utils.ConcurrentMap)
+ group := _group.(*mapx.ConcurrentMap)
if v, ok := group.Get(index); ok {
v.(*smux.Session).Close()
}
group.Set(index, session)
// s.clientControlConns.Set(key, session)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
for {
if s.isStop {
return
@@ -254,7 +265,7 @@ func (s *MuxBridge) callback(inConn net.Conn, serverID, key string) {
time.Sleep(time.Second * 3)
continue
}
- group := _group.(*utils.ConcurrentMap)
+ group := _group.(*mapx.ConcurrentMap)
keys := group.Keys()
keysLen := len(keys)
i := 0
@@ -280,10 +291,20 @@ func (s *MuxBridge) callback(inConn net.Conn, serverID, key string) {
die1 := make(chan bool, 1)
die2 := make(chan bool, 1)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
io.Copy(stream, inConn)
die1 <- true
}()
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
io.Copy(inConn, stream)
die2 <- true
}()
diff --git a/services/mux/mux_client.go b/services/mux/mux_client.go
index a40914f..de445b8 100644
--- a/services/mux/mux_client.go
+++ b/services/mux/mux_client.go
@@ -6,13 +6,17 @@ import (
"io"
logger "log"
"net"
+ "runtime/debug"
+ "strings"
"time"
- "github.com/golang/snappy"
"github.com/snail007/goproxy/services"
"github.com/snail007/goproxy/services/kcpcfg"
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/jumper"
+ "github.com/snail007/goproxy/utils/mapx"
+
+ "github.com/golang/snappy"
//"github.com/xtaci/smux"
smux "github.com/hashicorp/yamux"
)
@@ -31,23 +35,35 @@ type MuxClientArgs struct {
KCP kcpcfg.KCPConfigArgs
Jumper *string
}
+type ClientUDPConnItem struct {
+ conn *smux.Stream
+ isActive bool
+ touchtime int64
+ srcAddr *net.UDPAddr
+ localAddr *net.UDPAddr
+ udpConn *net.UDPConn
+ connid string
+}
type MuxClient struct {
cfg MuxClientArgs
isStop bool
- sessions utils.ConcurrentMap
+ sessions mapx.ConcurrentMap
log *logger.Logger
jumper *jumper.Jumper
+ udpConns mapx.ConcurrentMap
}
func NewMuxClient() services.Service {
return &MuxClient{
cfg: MuxClientArgs{},
isStop: false,
- sessions: utils.NewConcurrentMap(),
+ sessions: mapx.NewConcurrentMap(),
+ udpConns: mapx.NewConcurrentMap(),
}
}
func (s *MuxClient) InitService() (err error) {
+ s.UDPGCDeamon()
return
}
@@ -89,8 +105,14 @@ func (s *MuxClient) StopService() {
if e != nil {
s.log.Printf("stop client service crashed,%s", e)
} else {
- s.log.Printf("service client stopped")
+ s.log.Printf("service client stoped")
}
+ s.cfg = MuxClientArgs{}
+ s.jumper = nil
+ s.log = nil
+ s.sessions = nil
+ s.udpConns = nil
+ s = nil
}()
s.isStop = true
for _, sess := range s.sessions.Items() {
@@ -178,7 +200,7 @@ func (s *MuxClient) Start(args interface{}, log *logger.Logger) (err error) {
stream.Close()
return
}
- s.log.Printf("worker[%d] signal revecived,server %s stream %s %s", i, serverID, ID, clientLocalAddr)
+ //s.log.Printf("worker[%d] signal revecived,server %s stream %s %s", i, serverID, ID, clientLocalAddr)
protocol := clientLocalAddr[:3]
localAddr := clientLocalAddr[4:]
if protocol == "udp" {
@@ -228,76 +250,135 @@ func (s *MuxClient) getParentConn() (conn net.Conn, err error) {
return
}
func (s *MuxClient) ServeUDP(inConn *smux.Stream, localAddr, ID string) {
-
+ var item *ClientUDPConnItem
+ var body []byte
+ var err error
+ srcAddr := ""
+ defer func() {
+ if item != nil {
+ (*item).conn.Close()
+ (*item).udpConn.Close()
+ s.udpConns.Remove(srcAddr)
+ inConn.Close()
+ }
+ }()
for {
if s.isStop {
return
}
- inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- srcAddr, body, err := utils.ReadUDPPacket(inConn)
- inConn.SetDeadline(time.Time{})
+ srcAddr, body, err = utils.ReadUDPPacket(inConn)
if err != nil {
- s.log.Printf("udp packet revecived fail, err: %s", err)
- s.log.Printf("connection %s released", ID)
- inConn.Close()
- break
+ if strings.Contains(err.Error(), "n != int(") {
+ continue
+ }
+ if !utils.IsNetDeadlineErr(err) && err != io.EOF {
+ s.log.Printf("udp packet revecived from bridge fail, err: %s", err)
+ }
+ return
+ }
+ if v, ok := s.udpConns.Get(srcAddr); !ok {
+ _srcAddr, _ := net.ResolveUDPAddr("udp", srcAddr)
+ zeroAddr, _ := net.ResolveUDPAddr("udp", ":")
+ _localAddr, _ := net.ResolveUDPAddr("udp", localAddr)
+ c, err := net.DialUDP("udp", zeroAddr, _localAddr)
+ if err != nil {
+ s.log.Printf("create local udp conn fail, err : %s", err)
+ inConn.Close()
+ return
+ }
+ item = &ClientUDPConnItem{
+ conn: inConn,
+ srcAddr: _srcAddr,
+ localAddr: _localAddr,
+ udpConn: c,
+ connid: ID,
+ }
+ s.udpConns.Set(srcAddr, item)
+ s.UDPRevecive(srcAddr, ID)
} else {
- //s.log.Printf("udp packet revecived:%s,%v", srcAddr, body)
+ item = v.(*ClientUDPConnItem)
+ }
+ (*item).touchtime = time.Now().Unix()
+ go (*item).udpConn.Write(body)
+ }
+}
+func (s *MuxClient) UDPRevecive(key, ID string) {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.log.Printf("udp conn %s connected", ID)
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID)
+ return
+ }
+ cui := v.(*ClientUDPConnItem)
+ buf := utils.LeakyBuffer.Get()
+ defer func() {
+ utils.LeakyBuffer.Put(buf)
+ cui.conn.Close()
+ cui.udpConn.Close()
+ s.udpConns.Remove(key)
+ s.log.Printf("udp conn %s released", ID)
+ }()
+ for {
+ n, err := cui.udpConn.Read(buf)
+ if err != nil {
+ if !utils.IsNetClosedErr(err) {
+ s.log.Printf("udp conn read udp packet fail , err: %s ", err)
+ }
+ return
+ }
+ cui.touchtime = time.Now().Unix()
go func() {
defer func() {
if e := recover(); e != nil {
- s.log.Printf("client processUDPPacket crashed,err: %s", e)
+ fmt.Printf("crashed:%s", string(debug.Stack()))
}
}()
- s.processUDPPacket(inConn, srcAddr, localAddr, body)
+ cui.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ _, err = cui.conn.Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n]))
+ cui.conn.SetWriteDeadline(time.Time{})
+ if err != nil {
+ cui.udpConn.Close()
+ return
+ }
}()
-
}
-
- }
- // }
+ }()
}
-func (s *MuxClient) processUDPPacket(inConn *smux.Stream, srcAddr, localAddr string, body []byte) {
- dstAddr, err := net.ResolveUDPAddr("udp", localAddr)
- if err != nil {
- s.log.Printf("can't resolve address: %s", err)
- inConn.Close()
- return
- }
- clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
- conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
- if err != nil {
- s.log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
- return
- }
- conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = conn.Write(body)
- conn.SetDeadline(time.Time{})
- if err != nil {
- s.log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
- return
- }
- //s.log.Printf("send udp packet to %s success", dstAddr.String())
- buf := make([]byte, 1024)
- conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- length, _, err := conn.ReadFromUDP(buf)
- conn.SetDeadline(time.Time{})
- if err != nil {
- s.log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
- return
- }
- respBody := buf[0:length]
- //s.log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody)
- bs := utils.UDPPacket(srcAddr, respBody)
- (*inConn).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = (*inConn).Write(bs)
- (*inConn).SetDeadline(time.Time{})
- if err != nil {
- s.log.Printf("send udp response fail ,ERR:%s", err)
- inConn.Close()
- return
- }
- //s.log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs)
+func (s *MuxClient) UDPGCDeamon() {
+ gctime := int64(30)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ if s.isStop {
+ return
+ }
+ timer := time.NewTicker(time.Second)
+ for {
+ <-timer.C
+ gcKeys := []string{}
+ s.udpConns.IterCb(func(key string, v interface{}) {
+ if time.Now().Unix()-v.(*ClientUDPConnItem).touchtime > gctime {
+ (*(v.(*ClientUDPConnItem).conn)).Close()
+ (v.(*ClientUDPConnItem).udpConn).Close()
+ gcKeys = append(gcKeys, key)
+ s.log.Printf("gc udp conn %s", v.(*ClientUDPConnItem).connid)
+ }
+ })
+ for _, k := range gcKeys {
+ s.udpConns.Remove(k)
+ }
+ gcKeys = nil
+ }
+ }()
}
func (s *MuxClient) ServeConn(inConn *smux.Stream, localAddr, ID string) {
var err error
@@ -331,10 +412,20 @@ func (s *MuxClient) ServeConn(inConn *smux.Stream, localAddr, ID string) {
die1 := make(chan bool, 1)
die2 := make(chan bool, 1)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
io.Copy(outConn, snappy.NewReader(inConn))
die1 <- true
}()
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
io.Copy(snappy.NewWriter(inConn), outConn)
die2 <- true
}()
diff --git a/services/mux/mux_server.go b/services/mux/mux_server.go
index b9321fe..bb511e4 100644
--- a/services/mux/mux_server.go
+++ b/services/mux/mux_server.go
@@ -16,12 +16,19 @@ import (
"github.com/snail007/goproxy/services/kcpcfg"
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/jumper"
+ "github.com/snail007/goproxy/utils/mapx"
"github.com/golang/snappy"
//"github.com/xtaci/smux"
smux "github.com/hashicorp/yamux"
)
+const (
+ CONN_CLIENT_CONTROL = uint8(1)
+ CONN_SERVER = uint8(4)
+ CONN_CLIENT = uint8(5)
+)
+
type MuxServerArgs struct {
Parent *string
ParentType *string
@@ -41,16 +48,19 @@ type MuxServerArgs struct {
KCP kcpcfg.KCPConfigArgs
Jumper *string
}
-
-type MuxUDPItem struct {
- packet *[]byte
- localAddr *net.UDPAddr
- srcAddr *net.UDPAddr
+type MuxServer struct {
+ cfg MuxServerArgs
+ sc utils.ServerChannel
+ sessions mapx.ConcurrentMap
+ lockChn chan bool
+ isStop bool
+ log *logger.Logger
+ jumper *jumper.Jumper
+ udpConns mapx.ConcurrentMap
}
type MuxServerManager struct {
cfg MuxServerArgs
- udpChn chan MuxUDPItem
serverID string
servers []*services.Service
log *logger.Logger
@@ -59,7 +69,6 @@ type MuxServerManager struct {
func NewMuxServerManager() services.Service {
return &MuxServerManager{
cfg: MuxServerArgs{},
- udpChn: make(chan MuxUDPItem, 50000),
serverID: utils.Uniqueid(),
servers: []*services.Service{},
}
@@ -140,6 +149,11 @@ func (s *MuxServerManager) StopService() {
for _, server := range s.servers {
(*server).Clean()
}
+ s.cfg = MuxServerArgs{}
+ s.log = nil
+ s.serverID = ""
+ s.servers = nil
+ s = nil
}
func (s *MuxServerManager) CheckArgs() (err error) {
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
@@ -158,36 +172,40 @@ func (s *MuxServerManager) InitService() (err error) {
return
}
-type MuxServer struct {
- cfg MuxServerArgs
- udpChn chan MuxUDPItem
- sc utils.ServerChannel
- sessions utils.ConcurrentMap
- lockChn chan bool
- isStop bool
- udpConn *net.Conn
- log *logger.Logger
- jumper *jumper.Jumper
-}
-
func NewMuxServer() services.Service {
return &MuxServer{
cfg: MuxServerArgs{},
- udpChn: make(chan MuxUDPItem, 50000),
lockChn: make(chan bool, 1),
- sessions: utils.NewConcurrentMap(),
+ sessions: mapx.NewConcurrentMap(),
isStop: false,
+ udpConns: mapx.NewConcurrentMap(),
}
}
+type MuxUDPConnItem struct {
+ conn *net.Conn
+ touchtime int64
+ srcAddr *net.UDPAddr
+ localAddr *net.UDPAddr
+ connid string
+}
+
func (s *MuxServer) StopService() {
defer func() {
e := recover()
if e != nil {
s.log.Printf("stop server service crashed,%s", e)
} else {
- s.log.Printf("service server stopped")
+ s.log.Printf("service server stoped")
}
+ s.cfg = MuxServerArgs{}
+ s.jumper = nil
+ s.lockChn = nil
+ s.log = nil
+ s.sc = utils.ServerChannel{}
+ s.sessions = nil
+ s.udpConns = nil
+ s = nil
}()
s.isStop = true
for _, sess := range s.sessions.Items() {
@@ -199,12 +217,9 @@ func (s *MuxServer) StopService() {
if s.sc.UDPListener != nil {
(*s.sc.UDPListener).Close()
}
- if s.udpConn != nil {
- (*s.udpConn).Close()
- }
}
func (s *MuxServer) InitService() (err error) {
- s.UDPConnDeamon()
+ s.UDPGCDeamon()
return
}
func (s *MuxServer) CheckArgs() (err error) {
@@ -241,12 +256,8 @@ func (s *MuxServer) Start(args interface{}, log *logger.Logger) (err error) {
p, _ := strconv.Atoi(port)
s.sc = utils.NewServerChannel(host, p, s.log)
if *s.cfg.IsUDP {
- err = s.sc.ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) {
- s.udpChn <- MuxUDPItem{
- packet: &packet,
- localAddr: localAddr,
- srcAddr: srcAddr,
- }
+ err = s.sc.ListenUDP(func(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr) {
+ s.UDPSend(packet, localAddr, srcAddr)
})
if err != nil {
return
@@ -280,10 +291,20 @@ func (s *MuxServer) Start(args interface{}, log *logger.Logger) (err error) {
die1 := make(chan bool, 1)
die2 := make(chan bool, 1)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
io.Copy(inConn, snappy.NewReader(outConn))
die1 <- true
}()
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
io.Copy(snappy.NewWriter(outConn), inConn)
die2 <- true
}()
@@ -317,7 +338,9 @@ func (s *MuxServer) GetOutConn() (outConn net.Conn, ID string, err error) {
}
outConn, err = s.GetConn(fmt.Sprintf("%d", i))
if err != nil {
- s.log.Printf("connection err: %s", err)
+ if !strings.Contains(err.Error(), "can not connect at same time") {
+ s.log.Printf("connection err: %s", err)
+ }
return
}
remoteAddr := "tcp:" + *s.cfg.Remote
@@ -372,6 +395,11 @@ func (s *MuxServer) GetConn(index string) (conn net.Conn, err error) {
s.sessions.Set(index, session)
s.log.Printf("session[%s] created", index)
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
for {
if s.isStop {
return
@@ -424,87 +452,128 @@ func (s *MuxServer) getParentConn() (conn net.Conn, err error) {
}
return
}
-func (s *MuxServer) UDPConnDeamon() {
+func (s *MuxServer) UDPGCDeamon() {
+ gctime := int64(30)
go func() {
defer func() {
- if err := recover(); err != nil {
- s.log.Printf("udp conn deamon crashed with err : %s \nstack: %s", err, string(debug.Stack()))
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
}
}()
- var outConn net.Conn
- var ID string
- var err error
+ if s.isStop {
+ return
+ }
+ timer := time.NewTicker(time.Second)
for {
- if s.isStop {
- return
- }
- item := <-s.udpChn
- RETRY:
- if s.isStop {
- return
- }
- if outConn == nil {
- for {
- if s.isStop {
- return
- }
- outConn, ID, err = s.GetOutConn()
- if err != nil {
- outConn = nil
- utils.CloseConn(&outConn)
- s.log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err)
- time.Sleep(time.Second * 3)
- continue
- } else {
- go func(outConn net.Conn, ID string) {
- if s.udpConn != nil {
- (*s.udpConn).Close()
- }
- s.udpConn = &outConn
- for {
- if s.isStop {
- return
- }
- outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- srcAddrFromConn, body, err := utils.ReadUDPPacket(outConn)
- outConn.SetDeadline(time.Time{})
- if err != nil {
- s.log.Printf("parse revecived udp packet fail, err: %s ,%v", err, body)
- s.log.Printf("UDP deamon connection %s exited", ID)
- break
- }
- //s.log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn)
- _srcAddr := strings.Split(srcAddrFromConn, ":")
- if len(_srcAddr) != 2 {
- s.log.Printf("parse revecived udp packet fail, addr error : %s", srcAddrFromConn)
- continue
- }
- port, _ := strconv.Atoi(_srcAddr[1])
- dstAddr := &net.UDPAddr{IP: net.ParseIP(_srcAddr[0]), Port: port}
- s.sc.UDPListener.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr)
- s.sc.UDPListener.SetDeadline(time.Time{})
- if err != nil {
- s.log.Printf("udp response to local %s fail,ERR:%s", srcAddrFromConn, err)
- continue
- }
- //s.log.Printf("udp response to local %s success , %v", srcAddrFromConn, body)
- }
- }(outConn, ID)
- break
- }
+ <-timer.C
+ gcKeys := []string{}
+ s.udpConns.IterCb(func(key string, v interface{}) {
+ if time.Now().Unix()-v.(*MuxUDPConnItem).touchtime > gctime {
+ (*(v.(*MuxUDPConnItem).conn)).Close()
+ gcKeys = append(gcKeys, key)
+ s.log.Printf("gc udp conn %s", v.(*MuxUDPConnItem).connid)
}
+ })
+ for _, k := range gcKeys {
+ s.udpConns.Remove(k)
}
- outConn.SetWriteDeadline(time.Now().Add(time.Second))
- _, err = outConn.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet))
- outConn.SetWriteDeadline(time.Time{})
- if err != nil {
- utils.CloseConn(&outConn)
- outConn = nil
- s.log.Printf("write udp packet to %s fail ,flush err:%s ,retrying...", *s.cfg.Parent, err)
- goto RETRY
- }
- //s.log.Printf("write packet %v", *item.packet)
+ gcKeys = nil
+ }
+ }()
+}
+func (s *MuxServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) {
+ var (
+ uc *MuxUDPConnItem
+ key = srcAddr.String()
+ ID string
+ err error
+ outconn net.Conn
+ )
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ for {
+ outconn, ID, err = s.GetOutConn()
+ if err != nil && strings.Contains(err.Error(), "can not connect at same time") {
+ time.Sleep(time.Millisecond * 500)
+ continue
+ } else {
+ break
+ }
+ }
+ if err != nil {
+ s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err)
+ return
+ }
+ uc = &MuxUDPConnItem{
+ conn: &outconn,
+ srcAddr: srcAddr,
+ localAddr: localAddr,
+ connid: ID,
+ }
+ s.udpConns.Set(key, uc)
+ s.UDPRevecive(key, ID)
+ } else {
+ uc = v.(*MuxUDPConnItem)
+ }
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ (*uc.conn).Close()
+ s.udpConns.Remove(key)
+ s.log.Printf("udp sender crashed with error : %s", e)
+ }
+ }()
+ uc.touchtime = time.Now().Unix()
+ (*uc.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ _, err = (*uc.conn).Write(utils.UDPPacket(srcAddr.String(), data))
+ (*uc.conn).SetWriteDeadline(time.Time{})
+ if err != nil {
+ s.log.Printf("write udp packet to %s fail ,flush err:%s ", *s.cfg.Parent, err)
+ }
+ }()
+}
+func (s *MuxServer) UDPRevecive(key, ID string) {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.log.Printf("udp conn %s connected", ID)
+ var uc *MuxUDPConnItem
+ defer func() {
+ if uc != nil {
+ (*uc.conn).Close()
+ }
+ s.udpConns.Remove(key)
+ s.log.Printf("udp conn %s released", ID)
+ }()
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID)
+ return
+ }
+ uc = v.(*MuxUDPConnItem)
+ for {
+ _, body, err := utils.ReadUDPPacket(*uc.conn)
+ if err != nil {
+ if strings.Contains(err.Error(), "n != int(") {
+ continue
+ }
+ if err != io.EOF {
+ s.log.Printf("udp conn read udp packet fail , err: %s ", err)
+ }
+ return
+ }
+ uc.touchtime = time.Now().Unix()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.sc.UDPListener.WriteToUDP(body, uc.srcAddr)
+ }()
}
}()
}
diff --git a/services/service.go b/services/service.go
index d126642..1be1fce 100644
--- a/services/service.go
+++ b/services/service.go
@@ -39,6 +39,7 @@ func GetService(name string) *ServiceItem {
func Stop(name string) {
if s, ok := servicesMap.Load(name); ok && s.(*ServiceItem).S != nil {
s.(*ServiceItem).S.Clean()
+ servicesMap.Delete(name)
}
}
func Run(name string, args interface{}) (service *ServiceItem, err error) {
diff --git a/services/socks/socks.go b/services/socks/socks.go
index b7856f2..1ca388b 100644
--- a/services/socks/socks.go
+++ b/services/socks/socks.go
@@ -8,6 +8,7 @@ import (
logger "log"
"net"
"runtime/debug"
+ "strconv"
"strings"
"time"
@@ -15,46 +16,63 @@ import (
"github.com/snail007/goproxy/services/kcpcfg"
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/conncrypt"
+ "github.com/snail007/goproxy/utils/datasize"
+ "github.com/snail007/goproxy/utils/dnsx"
+ "github.com/snail007/goproxy/utils/iolimiter"
+ "github.com/snail007/goproxy/utils/lb"
+ "github.com/snail007/goproxy/utils/mapx"
"github.com/snail007/goproxy/utils/socks"
+
"golang.org/x/crypto/ssh"
)
type SocksArgs struct {
- Parent *string
- ParentType *string
- Local *string
- LocalType *string
- CertFile *string
- KeyFile *string
- CaCertFile *string
- CaCertBytes []byte
- CertBytes []byte
- KeyBytes []byte
- SSHKeyFile *string
- SSHKeyFileSalt *string
- SSHPassword *string
- SSHUser *string
- SSHKeyBytes []byte
- SSHAuthMethod ssh.AuthMethod
- Timeout *int
- Always *bool
- Interval *int
- Blocked *string
- Direct *string
- AuthFile *string
- Auth *[]string
- AuthURL *string
- AuthURLOkCode *int
- AuthURLTimeout *int
- AuthURLRetry *int
- KCP kcpcfg.KCPConfigArgs
- LocalIPS *[]string
- DNSAddress *string
- DNSTTL *int
- LocalKey *string
- ParentKey *string
- LocalCompress *bool
- ParentCompress *bool
+ Parent *[]string
+ ParentType *string
+ Local *string
+ LocalType *string
+ CertFile *string
+ KeyFile *string
+ CaCertFile *string
+ CaCertBytes []byte
+ CertBytes []byte
+ KeyBytes []byte
+ SSHKeyFile *string
+ SSHKeyFileSalt *string
+ SSHPassword *string
+ SSHUser *string
+ SSHKeyBytes []byte
+ SSHAuthMethod ssh.AuthMethod
+ Timeout *int
+ Always *bool
+ Interval *int
+ Blocked *string
+ Direct *string
+ ParentAuth *string
+ AuthFile *string
+ Auth *[]string
+ AuthURL *string
+ AuthURLOkCode *int
+ AuthURLTimeout *int
+ AuthURLRetry *int
+ KCP kcpcfg.KCPConfigArgs
+ LocalIPS *[]string
+ DNSAddress *string
+ DNSTTL *int
+ LocalKey *string
+ ParentKey *string
+ LocalCompress *bool
+ ParentCompress *bool
+ LoadBalanceMethod *string
+ LoadBalanceTimeout *int
+ LoadBalanceRetryTime *int
+ LoadBalanceHashTarget *bool
+ LoadBalanceOnlyHA *bool
+
+ RateLimit *string
+ RateLimitBytes float64
+ BindListen *bool
+ Debug *bool
}
type Socks struct {
cfg SocksArgs
@@ -64,11 +82,12 @@ type Socks struct {
lockChn chan bool
udpSC utils.ServerChannel
sc *utils.ServerChannel
- domainResolver utils.DomainResolver
+ domainResolver dnsx.DomainResolver
isStop bool
- userConns utils.ConcurrentMap
+ userConns mapx.ConcurrentMap
log *logger.Logger
- udpRelatedPacketConns utils.ConcurrentMap
+ lb *lb.Group
+ udpRelatedPacketConns mapx.ConcurrentMap
udpLocalKey []byte
udpParentKey []byte
}
@@ -80,14 +99,14 @@ func NewSocks() services.Service {
basicAuth: utils.BasicAuth{},
lockChn: make(chan bool, 1),
isStop: false,
- userConns: utils.NewConcurrentMap(),
- udpRelatedPacketConns: utils.NewConcurrentMap(),
+ userConns: mapx.NewConcurrentMap(),
+ udpRelatedPacketConns: mapx.NewConcurrentMap(),
}
}
func (s *Socks) CheckArgs() (err error) {
- if *s.cfg.LocalType == "tls" || (*s.cfg.Parent != "" && *s.cfg.ParentType == "tls") {
+ if *s.cfg.LocalType == "tls" || (len(*s.cfg.Parent) > 0 && *s.cfg.ParentType == "tls") {
s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
if err != nil {
return
@@ -100,7 +119,12 @@ func (s *Socks) CheckArgs() (err error) {
}
}
}
- if *s.cfg.Parent != "" {
+
+ if len(*s.cfg.Parent) == 1 && (*s.cfg.Parent)[0] == "" {
+ (*s.cfg.Parent) = []string{}
+ }
+
+ if len(*s.cfg.Parent) > 0 {
if *s.cfg.ParentType == "" {
err = fmt.Errorf("parent type unkown,use -T ")
return
@@ -136,32 +160,46 @@ func (s *Socks) CheckArgs() (err error) {
}
}
}
+ if *s.cfg.RateLimit != "0" && *s.cfg.RateLimit != "" {
+ var size uint64
+ size, err = datasize.Parse(*s.cfg.RateLimit)
+ if err != nil {
+ err = fmt.Errorf("parse rate limit size error,ERR:%s", err)
+ return
+ }
+ s.cfg.RateLimitBytes = float64(size)
+ }
s.udpLocalKey = s.LocalUDPKey()
s.udpParentKey = s.ParentUDPKey()
- //s.log.Printf("udpLocalKey : %v , udpParentKey : %v", s.udpLocalKey, s.udpParentKey)
return
}
func (s *Socks) InitService() (err error) {
s.InitBasicAuth()
if *s.cfg.DNSAddress != "" {
- (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log)
+ (*s).domainResolver = dnsx.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log)
}
- if *s.cfg.Parent != "" {
+ if len(*s.cfg.Parent) > 0 {
s.checker = utils.NewChecker(*s.cfg.Timeout, int64(*s.cfg.Interval), *s.cfg.Blocked, *s.cfg.Direct, s.log)
+ s.InitLB()
}
if *s.cfg.ParentType == "ssh" {
- e := s.ConnectSSH()
+ e := s.ConnectSSH(s.Resolve(s.lb.Select("", *s.cfg.LoadBalanceOnlyHA)))
if e != nil {
err = fmt.Errorf("init service fail, ERR: %s", e)
return
}
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
//循环检查ssh网络连通性
for {
if s.isStop {
return
}
- conn, err := utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout*2)
+ conn, err := utils.ConnectHost(s.Resolve(s.lb.Select("", *s.cfg.LoadBalanceOnlyHA)), *s.cfg.Timeout*2)
if err == nil {
conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
_, err = conn.Write([]byte{0})
@@ -172,7 +210,7 @@ func (s *Socks) InitService() (err error) {
s.sshClient.Close()
}
s.log.Printf("ssh offline, retrying...")
- s.ConnectSSH()
+ s.ConnectSSH(s.Resolve(s.lb.Select("", *s.cfg.LoadBalanceOnlyHA)))
} else {
conn.Close()
}
@@ -191,11 +229,26 @@ func (s *Socks) StopService() {
if e != nil {
s.log.Printf("stop socks service crashed,%s", e)
} else {
- s.log.Printf("service socks stopped")
+ s.log.Printf("service socks stoped")
}
+ s.basicAuth = utils.BasicAuth{}
+ s.cfg = SocksArgs{}
+ s.checker = utils.Checker{}
+ s.domainResolver = dnsx.DomainResolver{}
+ s.lb = nil
+ s.lockChn = nil
+ s.log = nil
+ s.sc = nil
+ s.sshClient = nil
+ s.udpLocalKey = nil
+ s.udpParentKey = nil
+ s.udpRelatedPacketConns = nil
+ s.udpSC = utils.ServerChannel{}
+ s.userConns = nil
+ s = nil
}()
s.isStop = true
- if *s.cfg.Parent != "" {
+ if len(*s.cfg.Parent) > 0 {
s.checker.Stop()
}
if s.sshClient != nil {
@@ -210,6 +263,9 @@ func (s *Socks) StopService() {
for _, c := range s.userConns.Items() {
(*c.(*net.Conn)).Close()
}
+ if s.lb != nil {
+ s.lb.Stop()
+ }
for _, c := range s.udpRelatedPacketConns.Items() {
(*c.(*net.UDPConn)).Close()
}
@@ -224,14 +280,14 @@ func (s *Socks) Start(args interface{}, log *logger.Logger) (err error) {
if err = s.InitService(); err != nil {
s.InitService()
}
- if *s.cfg.Parent != "" {
- s.log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent)
+ if len(*s.cfg.Parent) > 0 {
+ s.log.Printf("use %s parent %v [ %s ]", *s.cfg.ParentType, *s.cfg.Parent, strings.ToUpper(*s.cfg.LoadBalanceMethod))
}
sc := utils.NewServerChannelHost(*s.cfg.Local, s.log)
if *s.cfg.LocalType == "tcp" {
err = sc.ListenTCP(s.socksConnCallback)
} else if *s.cfg.LocalType == "tls" {
- err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.socksConnCallback)
+ err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes, s.socksConnCallback)
} else if *s.cfg.LocalType == "kcp" {
err = sc.ListenKCP(s.cfg.KCP, s.socksConnCallback, s.log)
}
@@ -261,116 +317,40 @@ func (s *Socks) socksConnCallback(inConn net.Conn) {
Password: *s.cfg.LocalKey,
})
}
- //协商开始
- //method select request
- inConn.SetReadDeadline(time.Now().Add(time.Second * 3))
- methodReq, err := socks.NewMethodsRequest(inConn)
- inConn.SetReadDeadline(time.Time{})
- if err != nil {
- methodReq.Reply(socks.Method_NONE_ACCEPTABLE)
- utils.CloseConn(&inConn)
- if err != io.EOF {
- s.log.Printf("new methods request fail,ERR: %s", err)
- }
- return
- }
-
- if !s.IsBasicAuth() {
- if !methodReq.Select(socks.Method_NO_AUTH) {
- methodReq.Reply(socks.Method_NONE_ACCEPTABLE)
- utils.CloseConn(&inConn)
- s.log.Printf("none method found : Method_NO_AUTH")
- return
- }
- //method select reply
- err = methodReq.Reply(socks.Method_NO_AUTH)
- if err != nil {
- s.log.Printf("reply answer data fail,ERR: %s", err)
- utils.CloseConn(&inConn)
- return
- }
- // s.log.Printf("% x", methodReq.Bytes())
+ //socks5 server
+ var serverConn *socks.ServerConn
+ udpIP, _, _ := net.SplitHostPort(inConn.LocalAddr().String())
+ if s.IsBasicAuth() {
+ serverConn = socks.NewServerConn(&inConn, time.Millisecond*time.Duration(*s.cfg.Timeout), &s.basicAuth, true, udpIP, nil)
} else {
- //auth
- if !methodReq.Select(socks.Method_USER_PASS) {
- methodReq.Reply(socks.Method_NONE_ACCEPTABLE)
- utils.CloseConn(&inConn)
- s.log.Printf("none method found : Method_USER_PASS")
- return
- }
- //method reply need auth
- err = methodReq.Reply(socks.Method_USER_PASS)
- if err != nil {
- s.log.Printf("reply answer data fail,ERR: %s", err)
- utils.CloseConn(&inConn)
- return
- }
- //read auth
- buf := make([]byte, 500)
- inConn.SetReadDeadline(time.Now().Add(time.Second * 3))
- n, err := inConn.Read(buf)
- inConn.SetReadDeadline(time.Time{})
- if err != nil {
- utils.CloseConn(&inConn)
- return
- }
- r := buf[:n]
- user := string(r[2 : r[1]+2])
- pass := string(r[2+r[1]+1:])
- //s.log.Printf("user:%s,pass:%s", user, pass)
- //auth
- _addr := strings.Split(inConn.RemoteAddr().String(), ":")
- if s.basicAuth.CheckUserPass(user, pass, _addr[0], "") {
- inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- inConn.Write([]byte{0x01, 0x00})
- inConn.SetDeadline(time.Time{})
-
- } else {
- inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- inConn.Write([]byte{0x01, 0x01})
- inConn.SetDeadline(time.Time{})
-
- utils.CloseConn(&inConn)
- return
- }
+ serverConn = socks.NewServerConn(&inConn, time.Millisecond*time.Duration(*s.cfg.Timeout), nil, true, udpIP, nil)
}
-
- //request detail
- request, err := socks.NewRequest(inConn)
- if err != nil {
- s.log.Printf("read request data fail,ERR: %s", err)
- utils.CloseConn(&inConn)
+ if err := serverConn.Handshake(); err != nil {
+ if !strings.HasSuffix(err.Error(), "EOF") {
+ s.log.Printf("handshake fail, ERR: %s", err)
+ }
+ inConn.Close()
return
}
- //协商结束
-
- switch request.CMD() {
- case socks.CMD_BIND:
- //bind 不支持
- request.TCPReply(socks.REP_UNKNOWN)
- utils.CloseConn(&inConn)
- return
- case socks.CMD_CONNECT:
- //tcp
- s.proxyTCP(&inConn, methodReq, request)
- case socks.CMD_ASSOCIATE:
- //udp
- s.proxyUDP(&inConn, methodReq, request)
+ if serverConn.IsUDP() {
+ s.proxyUDP(&inConn, serverConn)
+ } else if serverConn.IsTCP() {
+ s.proxyTCP(&inConn, serverConn)
}
-
}
-func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) {
+func (s *Socks) proxyTCP(inConn *net.Conn, serverConn *socks.ServerConn) {
var outConn net.Conn
var err interface{}
+ lbAddr := ""
useProxy := true
tryCount := 0
maxTryCount := 5
//防止死循环
- if s.IsDeadLoop((*inConn).LocalAddr().String(), request.Host()) {
+ if s.IsDeadLoop((*inConn).LocalAddr().String(), serverConn.Host()) {
utils.CloseConn(inConn)
- s.log.Printf("dead loop detected , %s", request.Host())
+ s.log.Printf("dead loop detected , %s", serverConn.Host())
utils.CloseConn(inConn)
return
}
@@ -378,33 +358,70 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
if s.isStop {
return
}
+
if *s.cfg.Always {
- outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr(), true)
+ selectAddr := (*inConn).RemoteAddr().String()
+ if utils.LBMethod(*s.cfg.LoadBalanceMethod) == lb.SELECT_HASH && *s.cfg.LoadBalanceHashTarget {
+ selectAddr = serverConn.Target()
+ }
+ lbAddr = s.lb.Select(selectAddr, *s.cfg.LoadBalanceOnlyHA)
+ //lbAddr = s.lb.Select((*inConn).RemoteAddr().String())
+ outConn, err = s.GetParentConn(lbAddr, serverConn)
+ if err != nil {
+ s.log.Printf("connect to parent fail, %s", err)
+ return
+ }
+ //handshake
+ //socks client
+ _, err = s.HandshakeSocksParent(&outConn, "tcp", serverConn.Target(), serverConn.AuthData(), false)
+ if err != nil {
+ if err != io.EOF {
+ s.log.Printf("handshake fail, %s", err)
+ }
+ return
+ }
} else {
- if *s.cfg.Parent != "" {
- host, _, _ := net.SplitHostPort(request.Addr())
+ if len(*s.cfg.Parent) > 0 {
+ host, _, _ := net.SplitHostPort(serverConn.Target())
useProxy := false
- if utils.IsIternalIP(host, *s.cfg.Always) {
+ if utils.IsInternalIP(host, *s.cfg.Always) {
useProxy = false
} else {
var isInMap bool
- useProxy, isInMap, _, _ = s.checker.IsBlocked(request.Addr())
+ useProxy, isInMap, _, _ = s.checker.IsBlocked(serverConn.Target())
if !isInMap {
- s.checker.Add(request.Addr(), s.Resolve(request.Addr()))
+ s.checker.Add(serverConn.Target(), s.Resolve(serverConn.Target()))
}
}
if useProxy {
- outConn, err = s.getOutConn(methodReq.Bytes(), request.Bytes(), request.Addr(), true)
+ selectAddr := (*inConn).RemoteAddr().String()
+ if utils.LBMethod(*s.cfg.LoadBalanceMethod) == lb.SELECT_HASH && *s.cfg.LoadBalanceHashTarget {
+ selectAddr = serverConn.Target()
+ }
+ lbAddr = s.lb.Select(selectAddr, *s.cfg.LoadBalanceOnlyHA)
+ //lbAddr = s.lb.Select((*inConn).RemoteAddr().String())
+ outConn, err = s.GetParentConn(lbAddr, serverConn)
+ if err != nil {
+ s.log.Printf("connect to parent fail, %s", err)
+ return
+ }
+ //handshake
+ //socks client
+ _, err = s.HandshakeSocksParent(&outConn, "tcp", serverConn.Target(), serverConn.AuthData(), false)
+ if err != nil {
+ s.log.Printf("handshake fail, %s", err)
+ return
+ }
} else {
- outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
+ outConn, err = s.GetDirectConn(s.Resolve(serverConn.Target()), (*inConn).LocalAddr().String())
}
} else {
- outConn, err = utils.ConnectHost(s.Resolve(request.Addr()), *s.cfg.Timeout)
+ outConn, err = s.GetDirectConn(s.Resolve(serverConn.Target()), (*inConn).LocalAddr().String())
useProxy = false
}
}
tryCount++
- if err == nil || tryCount > maxTryCount || *s.cfg.Parent == "" {
+ if err == nil || tryCount > maxTryCount || len(*s.cfg.Parent) == 0 {
break
} else {
s.log.Printf("get out conn fail,%s,retrying...", err)
@@ -413,28 +430,37 @@ func (s *Socks) proxyTCP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
}
if err != nil {
s.log.Printf("get out conn fail,%s", err)
- request.TCPReply(socks.REP_NETWOR_UNREACHABLE)
return
}
- s.log.Printf("use proxy %v : %s", useProxy, request.Addr())
+ s.log.Printf("use proxy %v : %s", useProxy, serverConn.Target())
- request.TCPReply(socks.REP_SUCCESS)
inAddr := (*inConn).RemoteAddr().String()
+ //outRemoteAddr := outConn.RemoteAddr().String()
//inLocalAddr := (*inConn).LocalAddr().String()
- s.log.Printf("conn %s - %s connected", inAddr, request.Addr())
+ if s.cfg.RateLimitBytes > 0 {
+ outConn = iolimiter.NewReaderConn(outConn, s.cfg.RateLimitBytes)
+ }
+
utils.IoBind(*inConn, outConn, func(err interface{}) {
- s.log.Printf("conn %s - %s released", inAddr, request.Addr())
+ s.log.Printf("conn %s - %s released", inAddr, serverConn.Target())
s.userConns.Remove(inAddr)
+ if len(*s.cfg.Parent) > 0 {
+ s.lb.DecreaseConns(lbAddr)
+ }
}, s.log)
if c, ok := s.userConns.Get(inAddr); ok {
(*c.(*net.Conn)).Close()
s.userConns.Remove(inAddr)
}
s.userConns.Set(inAddr, inConn)
+ if len(*s.cfg.Parent) > 0 {
+ s.lb.IncreasConns(lbAddr)
+ }
+ s.log.Printf("conn %s - %s connected", inAddr, serverConn.Target())
}
-func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string, handshake bool) (outConn net.Conn, err interface{}) {
+func (s *Socks) GetParentConn(parentAddress string, serverConn *socks.ServerConn) (outConn net.Conn, err interface{}) {
switch *s.cfg.ParentType {
case "kcp":
fallthrough
@@ -442,13 +468,15 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string, handshake
fallthrough
case "tcp":
if *s.cfg.ParentType == "tls" {
- var _outConn tls.Conn
- _outConn, err = utils.TlsConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
- outConn = net.Conn(&_outConn)
+ var _conn tls.Conn
+ _conn, err = utils.TlsConnectHost(parentAddress, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes)
+ if err == nil {
+ outConn = net.Conn(&_conn)
+ }
} else if *s.cfg.ParentType == "kcp" {
- outConn, err = utils.ConnectKCPHost(s.Resolve(*s.cfg.Parent), s.cfg.KCP)
+ outConn, err = utils.ConnectKCPHost(parentAddress, s.cfg.KCP)
} else {
- outConn, err = utils.ConnectHost(s.Resolve(*s.cfg.Parent), *s.cfg.Timeout)
+ outConn, err = utils.ConnectHost(parentAddress, *s.cfg.Timeout)
}
if err != nil {
err = fmt.Errorf("connect fail,%s", err)
@@ -462,44 +490,6 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string, handshake
Password: *s.cfg.ParentKey,
})
}
- if !handshake {
- return
- }
- var buf = make([]byte, 1024)
- //var n int
- outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = outConn.Write(methodBytes)
- outConn.SetDeadline(time.Time{})
- if err != nil {
- err = fmt.Errorf("write method fail,%s", err)
- return
- }
- outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = outConn.Read(buf)
- outConn.SetDeadline(time.Time{})
- if err != nil {
- err = fmt.Errorf("read method reply fail,%s", err)
- return
- }
- //resp := buf[:n]
- //s.log.Printf("resp:%v", resp)
- outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = outConn.Write(reqBytes)
- outConn.SetDeadline(time.Time{})
- if err != nil {
- err = fmt.Errorf("write req detail fail,%s", err)
- return
- }
- outConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = outConn.Read(buf)
- outConn.SetDeadline(time.Time{})
- if err != nil {
- err = fmt.Errorf("read req reply fail,%s", err)
- return
- }
- //result := buf[:n]
- //s.log.Printf("result:%v", result)
-
case "ssh":
maxTryCount := 1
tryCount := 0
@@ -515,17 +505,17 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string, handshake
}
wait <- true
}()
- outConn, err = s.sshClient.Dial("tcp", host)
+ outConn, err = s.sshClient.Dial("tcp", serverConn.Target())
}()
select {
case <-wait:
case <-time.After(time.Millisecond * time.Duration(*s.cfg.Timeout) * 2):
- err = fmt.Errorf("ssh dial %s timeout", host)
+ err = fmt.Errorf("ssh dial %s timeout", serverConn.Target())
s.sshClient.Close()
}
if err != nil {
s.log.Printf("connect ssh fail, ERR: %s, retrying...", err)
- e := s.ConnectSSH()
+ e := s.ConnectSSH(parentAddress)
if e == nil {
tryCount++
time.Sleep(time.Second * 3)
@@ -538,7 +528,7 @@ func (s *Socks) getOutConn(methodBytes, reqBytes []byte, host string, handshake
return
}
-func (s *Socks) ConnectSSH() (err error) {
+func (s *Socks) ConnectSSH(lbAddr string) (err error) {
select {
case s.lockChn <- true:
default:
@@ -556,7 +546,7 @@ func (s *Socks) ConnectSSH() (err error) {
if s.sshClient != nil {
s.sshClient.Close()
}
- s.sshClient, err = ssh.Dial("tcp", s.Resolve(*s.cfg.Parent), &config)
+ s.sshClient, err = ssh.Dial("tcp", s.Resolve(lbAddr), &config)
<-s.lockChn
return
}
@@ -585,6 +575,27 @@ func (s *Socks) InitBasicAuth() (err error) {
}
return
}
+func (s *Socks) InitLB() {
+ configs := lb.BackendsConfig{}
+ for _, addr := range *s.cfg.Parent {
+ _addrInfo := strings.Split(addr, "@")
+ _addr := _addrInfo[0]
+ weight := 1
+ if len(_addrInfo) == 2 {
+ weight, _ = strconv.Atoi(_addrInfo[1])
+ }
+ configs = append(configs, &lb.BackendConfig{
+ Address: _addr,
+ Weight: weight,
+ ActiveAfter: 1,
+ InactiveAfter: 2,
+ Timeout: time.Duration(*s.cfg.LoadBalanceTimeout) * time.Millisecond,
+ RetryTime: time.Duration(*s.cfg.LoadBalanceRetryTime) * time.Millisecond,
+ })
+ }
+ LB := lb.NewGroup(utils.LBMethod(*s.cfg.LoadBalanceMethod), configs, &s.domainResolver, s.log, *s.cfg.Debug)
+ s.lb = &LB
+}
func (s *Socks) IsBasicAuth() bool {
return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 || *s.cfg.AuthURL != ""
}
@@ -602,7 +613,7 @@ func (s *Socks) IsDeadLoop(inLocalAddr string, host string) bool {
if *s.cfg.DNSAddress != "" {
outIPs = []net.IP{net.ParseIP(s.Resolve(outDomain))}
} else {
- outIPs, err = utils.MyLookupIP(outDomain)
+ outIPs, err = utils.LookupIP(outDomain)
}
if err == nil {
for _, ip := range outIPs {
@@ -638,3 +649,37 @@ func (s *Socks) Resolve(address string) string {
}
return ip
}
+func (s *Socks) GetDirectConn(address string, localAddr string) (conn net.Conn, err error) {
+ if !*s.cfg.BindListen {
+ return utils.ConnectHost(address, *s.cfg.Timeout)
+ }
+ ip, _, _ := net.SplitHostPort(localAddr)
+ if utils.IsInternalIP(ip, false) {
+ return utils.ConnectHost(address, *s.cfg.Timeout)
+ }
+ local, _ := net.ResolveTCPAddr("tcp", ip+":0")
+ d := net.Dialer{
+ Timeout: time.Millisecond * time.Duration(*s.cfg.Timeout),
+ LocalAddr: local,
+ }
+ conn, err = d.Dial("tcp", address)
+ return
+}
+func (s *Socks) HandshakeSocksParent(outconn *net.Conn, network, dstAddr string, auth socks.Auth, fromSS bool) (client *socks.ClientConn, err error) {
+ if *s.cfg.ParentAuth != "" {
+ a := strings.Split(*s.cfg.ParentAuth, ":")
+ if len(a) != 2 {
+ err = fmt.Errorf("parent auth data format error")
+ return
+ }
+ client = socks.NewClientConn(outconn, network, dstAddr, time.Millisecond*time.Duration(*s.cfg.Timeout), &socks.Auth{User: a[0], Password: a[1]}, nil)
+ } else {
+ if !fromSS && !s.IsBasicAuth() && auth.Password != "" && auth.User != "" {
+ client = socks.NewClientConn(outconn, network, dstAddr, time.Millisecond*time.Duration(*s.cfg.Timeout), &auth, nil)
+ } else {
+ client = socks.NewClientConn(outconn, network, dstAddr, time.Millisecond*time.Duration(*s.cfg.Timeout), nil, nil)
+ }
+ }
+ err = client.Handshake()
+ return
+}
diff --git a/services/socks/udp.go b/services/socks/udp.go
index ad71a01..5d9ebd1 100644
--- a/services/socks/udp.go
+++ b/services/socks/udp.go
@@ -44,7 +44,7 @@ func (s *Socks) LocalUDPKey() (key []byte) {
}
return
}
-func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, request socks.Request) {
+func (s *Socks) proxyUDP(inConn *net.Conn, serverConn *socks.ServerConn) {
defer func() {
if e := recover(); e != nil {
s.log.Printf("udp local->out io copy crashed:\n%s\n%s", e, string(debug.Stack()))
@@ -54,23 +54,12 @@ func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
utils.CloseConn(inConn)
return
}
- srcIP, _, _ := net.SplitHostPort((*inConn).RemoteAddr().String())
inconnRemoteAddr := (*inConn).RemoteAddr().String()
localAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
- udpListener, err := net.ListenUDP("udp", localAddr)
- if err != nil {
- (*inConn).Close()
- udpListener.Close()
- s.log.Printf("udp bind fail , %s", err)
- return
- }
- host, _, _ := net.SplitHostPort((*inConn).LocalAddr().String())
- _, port, _ := net.SplitHostPort(udpListener.LocalAddr().String())
- if len(*s.cfg.LocalIPS) > 0 {
- host = (*s.cfg.LocalIPS)[0]
- }
- s.log.Printf("proxy udp on %s , for %s", net.JoinHostPort(host, port), inconnRemoteAddr)
- request.UDPReply(socks.REP_SUCCESS, net.JoinHostPort(host, port))
+ udpListener := serverConn.UDPConnListener
+ srcIP, _, _ := net.SplitHostPort((*inConn).RemoteAddr().String())
+ s.log.Printf("proxy udp on %s , for %s", udpListener.LocalAddr(), inconnRemoteAddr)
+
s.userConns.Set(inconnRemoteAddr, inConn)
var (
outUDPConn *net.UDPConn
@@ -134,15 +123,15 @@ func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
}
}()
useProxy := true
- if *s.cfg.Parent != "" {
- dstHost, _, _ := net.SplitHostPort(request.Addr())
- if utils.IsIternalIP(dstHost, *s.cfg.Always) {
+ if len(*s.cfg.Parent) > 0 {
+ dstHost, _, _ := net.SplitHostPort(serverConn.Target())
+ if utils.IsInternalIP(dstHost, *s.cfg.Always) {
useProxy = false
} else {
var isInMap bool
- useProxy, isInMap, _, _ = s.checker.IsBlocked(request.Addr())
+ useProxy, isInMap, _, _ = s.checker.IsBlocked(serverConn.Target())
if !isInMap {
- s.checker.Add(request.Addr(), s.Resolve(request.Addr()))
+ s.checker.Add(serverConn.Target(), s.Resolve(serverConn.Target()))
}
}
} else {
@@ -150,13 +139,17 @@ func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
}
if useProxy {
//parent proxy
- outconn, err := s.getOutConn(nil, nil, "", false)
+ lbAddr := s.lb.Select((*inConn).RemoteAddr().String(), *s.cfg.LoadBalanceOnlyHA)
+ outconn, err := s.GetParentConn(lbAddr, serverConn)
+ //outconn, err := s.GetParentConn(nil, nil, "", false)
if err != nil {
clean("connnect fail", fmt.Sprintf("%s", err))
return
}
- client := socks.NewClientConn(&outconn, "udp", request.Addr(), time.Millisecond*time.Duration(*s.cfg.Timeout), nil, nil)
- if err = client.Handshake(); err != nil {
+
+ client, err := s.HandshakeSocksParent(&outconn, "udp", serverConn.Target(), serverConn.AuthData(), false)
+
+ if err != nil {
clean("handshake fail", fmt.Sprintf("%s", err))
return
}
@@ -179,7 +172,7 @@ func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
//s.log.Printf("parent udp address %s", client.UDPAddr)
destAddr, _ = net.ResolveUDPAddr("udp", client.UDPAddr)
}
- s.log.Printf("use proxy %v : udp %s", useProxy, request.Addr())
+ s.log.Printf("use proxy %v : udp %s", useProxy, serverConn.Target())
//relay
for {
buf := utils.LeakyBuffer.Get()
@@ -245,6 +238,7 @@ func (s *Socks) proxyUDP(inConn *net.Conn, methodReq socks.MethodsRequest, reque
}
continue
}
+
//var dlen = n
if useProxy {
//forward to local
diff --git a/services/sps/socksudp.go b/services/sps/socksudp.go
index e59c7d9..cc52f38 100644
--- a/services/sps/socksudp.go
+++ b/services/sps/socksudp.go
@@ -1,50 +1,17 @@
package sps
import (
- "crypto/md5"
"fmt"
"net"
"runtime/debug"
- "strconv"
"strings"
"time"
"github.com/snail007/goproxy/utils"
goaes "github.com/snail007/goproxy/utils/aes"
- "github.com/snail007/goproxy/utils/conncrypt"
"github.com/snail007/goproxy/utils/socks"
)
-func (s *SPS) ParentUDPKey() (key []byte) {
- switch *s.cfg.ParentType {
- case "tcp":
- if *s.cfg.ParentKey != "" {
- v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.ParentKey)))
- return []byte(v)[:24]
- }
- case "tls":
- return s.cfg.KeyBytes[:24]
- case "kcp":
- v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.KCP.Key)))
- return []byte(v)[:24]
- }
- return
-}
-func (s *SPS) LocalUDPKey() (key []byte) {
- switch *s.cfg.LocalType {
- case "tcp":
- if *s.cfg.LocalKey != "" {
- v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.LocalKey)))
- return []byte(v)[:24]
- }
- case "tls":
- return s.cfg.KeyBytes[:24]
- case "kcp":
- v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.KCP.Key)))
- return []byte(v)[:24]
- }
- return
-}
func (s *SPS) proxyUDP(inConn *net.Conn, serverConn *socks.ServerConn) {
defer func() {
if e := recover(); e != nil {
@@ -123,41 +90,19 @@ func (s *SPS) proxyUDP(inConn *net.Conn, serverConn *socks.ServerConn) {
}
}()
//parent proxy
- outconn, err := s.outPool.Get()
+ lbAddr := s.lb.Select((*inConn).RemoteAddr().String(), *s.cfg.LoadBalanceOnlyHA)
+
+ outconn, err := s.GetParentConn(lbAddr)
//outconn, err := s.GetParentConn(nil, nil, "", false)
if err != nil {
clean("connnect fail", fmt.Sprintf("%s", err))
return
}
- if *s.cfg.ParentCompress {
- outconn = utils.NewCompConn(outconn)
- }
- if *s.cfg.ParentKey != "" {
- outconn = conncrypt.New(outconn, &conncrypt.Config{
- Password: *s.cfg.ParentKey,
- })
- }
-
s.log.Printf("connect %s for udp", serverConn.Target())
//socks client
- var client *socks.ClientConn
- auth := serverConn.AuthData()
- if *s.cfg.ParentAuth != "" {
- a := strings.Split(*s.cfg.ParentAuth, ":")
- if len(a) != 2 {
- err = fmt.Errorf("parent auth data format error")
- return
- }
- client = socks.NewClientConn(&outconn, "udp", serverConn.Target(), time.Millisecond*time.Duration(*s.cfg.Timeout), &socks.Auth{User: a[0], Password: a[1]}, nil)
- } else {
- if !s.IsBasicAuth() && auth.Password != "" && auth.User != "" {
- client = socks.NewClientConn(&outconn, "udp", serverConn.Target(), time.Millisecond*time.Duration(*s.cfg.Timeout), &auth, nil)
- } else {
- client = socks.NewClientConn(&outconn, "udp", serverConn.Target(), time.Millisecond*time.Duration(*s.cfg.Timeout), nil, nil)
- }
- }
- if err = client.Handshake(); err != nil {
+ client, err := s.HandshakeSocksParent(&outconn, "udp", serverConn.Target(), serverConn.AuthData(), false)
+ if err != nil {
clean("handshake fail", fmt.Sprintf("%s", err))
return
}
@@ -181,9 +126,9 @@ func (s *SPS) proxyUDP(inConn *net.Conn, serverConn *socks.ServerConn) {
//s.log.Printf("parent udp address %s", client.UDPAddr)
destAddr, _ = net.ResolveUDPAddr("udp", client.UDPAddr)
//relay
- buf := utils.LeakyBuffer.Get()
- defer utils.LeakyBuffer.Put(buf)
for {
+ buf := utils.LeakyBuffer.Get()
+ defer utils.LeakyBuffer.Put(buf)
n, srcAddr, err := udpListener.ReadFromUDP(buf)
if err != nil {
s.log.Printf("udp listener read fail, %s", err.Error())
@@ -208,76 +153,42 @@ func (s *SPS) proxyUDP(inConn *net.Conn, serverConn *socks.ServerConn) {
} else {
err = p.Parse(buf[:n])
}
+ //err = p.Parse(buf[:n])
if err != nil {
s.log.Printf("udp listener parse packet fail, %s", err.Error())
continue
}
-
- port, _ := strconv.Atoi(p.Port())
-
if v, ok := s.udpRelatedPacketConns.Get(srcAddr.String()); !ok {
- if destAddr == nil {
- destAddr = &net.UDPAddr{IP: net.ParseIP(p.Host()), Port: port}
- }
outUDPConn, err = net.DialUDP("udp", localAddr, destAddr)
if err != nil {
s.log.Printf("create out udp conn fail , %s , from : %s", err, srcAddr)
continue
}
s.udpRelatedPacketConns.Set(srcAddr.String(), outUDPConn)
- go func() {
- defer func() {
- if e := recover(); e != nil {
- s.log.Printf("udp out->local io copy crashed:\n%s\n%s", e, string(debug.Stack()))
- }
- }()
- defer s.udpRelatedPacketConns.Remove(srcAddr.String())
- //out->local io copy
- buf := utils.LeakyBuffer.Get()
- defer utils.LeakyBuffer.Put(buf)
- for {
- outUDPConn.SetReadDeadline(time.Now().Add(time.Second * 5))
- n, err := outUDPConn.Read(buf)
- outUDPConn.SetReadDeadline(time.Time{})
+ utils.UDPCopy(udpListener, outUDPConn, srcAddr, 0, func(data []byte) []byte {
+ //forward to local
+ var v []byte
+ //convert parent data to raw
+ if len(s.udpParentKey) > 0 {
+ v, err = goaes.Decrypt(s.udpParentKey, data)
if err != nil {
- s.log.Printf("read out udp data fail , %s , from : %s", err, srcAddr)
- if isClosedErr(err) {
- return
- }
- continue
- }
- //var dlen = n
- //forward to local
- var v []byte
- //convert parent data to raw
- if len(s.udpParentKey) > 0 {
- v, err = goaes.Decrypt(s.udpParentKey, buf[:n])
- if err != nil {
- s.log.Printf("udp outconn parse packet fail, %s", err.Error())
- continue
- }
- } else {
- v = buf[:n]
- }
- //now v is raw, try convert v to local
- if len(s.udpLocalKey) > 0 {
- v, _ = goaes.Encrypt(s.udpLocalKey, v)
- }
- _, err = udpListener.WriteTo(v, srcAddr)
- // _, err = udpListener.WriteTo(buf[:n], srcAddr)
-
- if err != nil {
- s.udpRelatedPacketConns.Remove(srcAddr.String())
- s.log.Printf("write out data to local fail , %s , from : %s", err, srcAddr)
- if isClosedErr(err) {
- return
- }
- continue
- } else {
- //s.log.Printf("send udp data to local success , len %d, for : %s", dlen, srcAddr)
+ s.log.Printf("udp outconn parse packet fail, %s", err.Error())
+ return []byte{}
}
+ } else {
+ v = data
}
- }()
+ //now v is raw, try convert v to local
+ if len(s.udpLocalKey) > 0 {
+ v, _ = goaes.Encrypt(s.udpLocalKey, v)
+ }
+ return v
+ }, func(err interface{}) {
+ s.udpRelatedPacketConns.Remove(srcAddr.String())
+ if err != nil {
+ s.log.Printf("udp out->local io copy crashed:\n%s\n%s", err, string(debug.Stack()))
+ }
+ })
} else {
outUDPConn = v.(*net.UDPConn)
}
diff --git a/services/sps/sps.go b/services/sps/sps.go
index 343f69b..da1bad9 100644
--- a/services/sps/sps.go
+++ b/services/sps/sps.go
@@ -2,6 +2,8 @@ package sps
import (
"bytes"
+ "crypto/md5"
+ "crypto/tls"
"encoding/base64"
"errors"
"fmt"
@@ -17,66 +19,92 @@ import (
"github.com/snail007/goproxy/services/kcpcfg"
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/conncrypt"
+ "github.com/snail007/goproxy/utils/datasize"
+ "github.com/snail007/goproxy/utils/dnsx"
+ "github.com/snail007/goproxy/utils/iolimiter"
+ "github.com/snail007/goproxy/utils/lb"
+ "github.com/snail007/goproxy/utils/mapx"
"github.com/snail007/goproxy/utils/sni"
"github.com/snail007/goproxy/utils/socks"
+ "github.com/snail007/goproxy/utils/ss"
)
type SPSArgs struct {
- Parent *string
- CertFile *string
- KeyFile *string
- CaCertFile *string
- CaCertBytes []byte
- CertBytes []byte
- KeyBytes []byte
- Local *string
- ParentType *string
- LocalType *string
- Timeout *int
- KCP kcpcfg.KCPConfigArgs
- ParentServiceType *string
- DNSAddress *string
- DNSTTL *int
- AuthFile *string
- Auth *[]string
- AuthURL *string
- AuthURLOkCode *int
- AuthURLTimeout *int
- AuthURLRetry *int
- LocalIPS *[]string
- ParentAuth *string
- LocalKey *string
- ParentKey *string
- LocalCompress *bool
- ParentCompress *bool
- DisableHTTP *bool
- DisableSocks5 *bool
+ Parent *[]string
+ CertFile *string
+ KeyFile *string
+ CaCertFile *string
+ CaCertBytes []byte
+ CertBytes []byte
+ KeyBytes []byte
+ Local *string
+ ParentType *string
+ LocalType *string
+ Timeout *int
+ KCP kcpcfg.KCPConfigArgs
+ ParentServiceType *string
+ DNSAddress *string
+ DNSTTL *int
+ AuthFile *string
+ Auth *[]string
+ AuthURL *string
+ AuthURLOkCode *int
+ AuthURLTimeout *int
+ AuthURLRetry *int
+ LocalIPS *[]string
+ ParentAuth *string
+ LocalKey *string
+ ParentKey *string
+ LocalCompress *bool
+ ParentCompress *bool
+ SSMethod *string
+ SSKey *string
+ ParentSSMethod *string
+ ParentSSKey *string
+ DisableHTTP *bool
+ DisableSocks5 *bool
+ DisableSS *bool
+ LoadBalanceMethod *string
+ LoadBalanceTimeout *int
+ LoadBalanceRetryTime *int
+ LoadBalanceHashTarget *bool
+ LoadBalanceOnlyHA *bool
+
+ RateLimit *string
+ RateLimitBytes float64
+ Debug *bool
}
type SPS struct {
- outPool utils.OutConn
cfg SPSArgs
- domainResolver utils.DomainResolver
+ domainResolver dnsx.DomainResolver
basicAuth utils.BasicAuth
serverChannels []*utils.ServerChannel
- userConns utils.ConcurrentMap
+ userConns mapx.ConcurrentMap
log *logger.Logger
- udpRelatedPacketConns utils.ConcurrentMap
+ localCipher *ss.Cipher
+ parentCipher *ss.Cipher
+ udpRelatedPacketConns mapx.ConcurrentMap
+ lb *lb.Group
udpLocalKey []byte
udpParentKey []byte
}
func NewSPS() services.Service {
return &SPS{
- outPool: utils.OutConn{},
cfg: SPSArgs{},
basicAuth: utils.BasicAuth{},
serverChannels: []*utils.ServerChannel{},
- userConns: utils.NewConcurrentMap(),
- udpRelatedPacketConns: utils.NewConcurrentMap(),
+ userConns: mapx.NewConcurrentMap(),
+ udpRelatedPacketConns: mapx.NewConcurrentMap(),
}
}
func (s *SPS) CheckArgs() (err error) {
- if *s.cfg.Parent == "" {
+
+ if len(*s.cfg.Parent) == 1 && (*s.cfg.Parent)[0] == "" {
+ (*s.cfg.Parent) = []string{}
+ }
+
+ if len(*s.cfg.Parent) == 0 {
err = fmt.Errorf("parent required for %s %s", *s.cfg.LocalType, *s.cfg.Local)
return
}
@@ -84,6 +112,10 @@ func (s *SPS) CheckArgs() (err error) {
err = fmt.Errorf("parent type unkown,use -T ")
return
}
+ if *s.cfg.ParentType == "ss" && (*s.cfg.ParentSSKey == "" || *s.cfg.ParentSSMethod == "") {
+ err = fmt.Errorf("ss parent need a ss key, set it by : -J ")
+ return
+ }
if *s.cfg.ParentType == "tls" || *s.cfg.LocalType == "tls" {
s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
if err != nil {
@@ -97,31 +129,45 @@ func (s *SPS) CheckArgs() (err error) {
}
}
}
+ if *s.cfg.RateLimit != "0" && *s.cfg.RateLimit != "" {
+ var size uint64
+ size, err = datasize.Parse(*s.cfg.RateLimit)
+ if err != nil {
+ err = fmt.Errorf("parse rate limit size error,ERR:%s", err)
+ return
+ }
+ s.cfg.RateLimitBytes = float64(size)
+ }
s.udpLocalKey = s.LocalUDPKey()
s.udpParentKey = s.ParentUDPKey()
return
}
func (s *SPS) InitService() (err error) {
+
if *s.cfg.DNSAddress != "" {
- (*s).domainResolver = utils.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log)
+ (*s).domainResolver = dnsx.NewDomainResolver(*s.cfg.DNSAddress, *s.cfg.DNSTTL, s.log)
}
- s.InitOutConnPool()
+
+ if len(*s.cfg.Parent) > 0 {
+ s.InitLB()
+ }
+
err = s.InitBasicAuth()
- return
-}
-func (s *SPS) InitOutConnPool() {
- if *s.cfg.ParentType == "tls" || *s.cfg.ParentType == "tcp" || *s.cfg.ParentType == "kcp" {
- //dur int, isTLS bool, certBytes, keyBytes []byte,
- //parent string, timeout int, InitialCap int, MaxCap int
- s.outPool = utils.NewOutConn(
- 0,
- *s.cfg.ParentType,
- s.cfg.KCP,
- s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes,
- *s.cfg.Parent,
- *s.cfg.Timeout,
- )
+ if *s.cfg.SSMethod != "" && *s.cfg.SSKey != "" {
+ s.localCipher, err = ss.NewCipher(*s.cfg.SSMethod, *s.cfg.SSKey)
+ if err != nil {
+ s.log.Printf("error generating cipher : %s", err)
+ return
+ }
}
+ if *s.cfg.ParentServiceType == "ss" {
+ s.parentCipher, err = ss.NewCipher(*s.cfg.ParentSSMethod, *s.cfg.ParentSSKey)
+ if err != nil {
+ s.log.Printf("error generating cipher : %s", err)
+ return
+ }
+ }
+ return
}
func (s *SPS) StopService() {
@@ -130,8 +176,21 @@ func (s *SPS) StopService() {
if e != nil {
s.log.Printf("stop sps service crashed,%s", e)
} else {
- s.log.Printf("service sps stopped")
+ s.log.Printf("service sps stoped")
}
+ s.basicAuth = utils.BasicAuth{}
+ s.cfg = SPSArgs{}
+ s.domainResolver = dnsx.DomainResolver{}
+ s.lb = nil
+ s.localCipher = nil
+ s.log = nil
+ s.parentCipher = nil
+ s.serverChannels = nil
+ s.udpLocalKey = nil
+ s.udpParentKey = nil
+ s.udpRelatedPacketConns = nil
+ s.userConns = nil
+ s = nil
}()
for _, sc := range s.serverChannels {
if sc.Listener != nil && *sc.Listener != nil {
@@ -149,6 +208,12 @@ func (s *SPS) StopService() {
(*(*c.(**net.Conn))).Close()
}
}
+ if s.lb != nil {
+ s.lb.Stop()
+ }
+ for _, c := range s.udpRelatedPacketConns.Items() {
+ (*c.(*net.UDPConn)).Close()
+ }
}
func (s *SPS) Start(args interface{}, log *logger.Logger) (err error) {
s.log = log
@@ -159,24 +224,30 @@ func (s *SPS) Start(args interface{}, log *logger.Logger) (err error) {
if err = s.InitService(); err != nil {
return
}
- s.log.Printf("use %s %s parent %s", *s.cfg.ParentType, *s.cfg.ParentServiceType, *s.cfg.Parent)
+
+ s.log.Printf("use %s %s parent %v [ %s ]", *s.cfg.ParentType, *s.cfg.ParentServiceType, *s.cfg.Parent, strings.ToUpper(*s.cfg.LoadBalanceMethod))
for _, addr := range strings.Split(*s.cfg.Local, ",") {
if addr != "" {
host, port, _ := net.SplitHostPort(addr)
p, _ := strconv.Atoi(port)
sc := utils.NewServerChannel(host, p, s.log)
+ s.serverChannels = append(s.serverChannels, &sc)
if *s.cfg.LocalType == "tcp" {
err = sc.ListenTCP(s.callback)
} else if *s.cfg.LocalType == "tls" {
err = sc.ListenTls(s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes, s.callback)
- } else if *s.cfg.LocalType == "kcp" {
+ } else if *s.cfg.LocalType == "tcp" {
err = sc.ListenKCP(s.cfg.KCP, s.callback, s.log)
}
+ if *s.cfg.ParentServiceType == "socks" {
+ err = s.RunSSUDP(addr)
+ } else {
+ s.log.Println("warn : udp only for socks parent ")
+ }
if err != nil {
return
}
- s.log.Printf("%s http(s)+socks proxy on %s", *s.cfg.LocalType, (*sc.Listener).Addr())
- s.serverChannels = append(s.serverChannels, &sc)
+ s.log.Printf("%s http(s)+socks+ss proxy on %s", *s.cfg.LocalType, (*sc.Listener).Addr())
}
}
return
@@ -200,22 +271,23 @@ func (s *SPS) callback(inConn net.Conn) {
})
}
var err error
+ lbAddr := ""
switch *s.cfg.ParentType {
case "kcp":
fallthrough
case "tcp":
fallthrough
case "tls":
- err = s.OutToTCP(&inConn)
+ lbAddr, err = s.OutToTCP(&inConn)
default:
err = fmt.Errorf("unkown parent type %s", *s.cfg.ParentType)
}
if err != nil {
- s.log.Printf("connect to %s parent %s fail, ERR:%s from %s", *s.cfg.ParentType, *s.cfg.Parent, err, inConn.RemoteAddr())
+ s.log.Printf("connect to %s parent %s fail, ERR:%s from %s", *s.cfg.ParentType, lbAddr, err, inConn.RemoteAddr())
utils.CloseConn(&inConn)
}
}
-func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
+func (s *SPS) OutToTCP(inConn *net.Conn) (lbAddr string, err error) {
enableUDP := *s.cfg.ParentServiceType == "socks"
udpIP, _, _ := net.SplitHostPort((*inConn).LocalAddr().String())
if len(*s.cfg.LocalIPS) > 0 {
@@ -243,9 +315,9 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
isSNI, _ := sni.ServerNameFromBytes(h)
*inConn = bInConn
address := ""
- var auth socks.Auth
+ var auth = socks.Auth{}
var forwardBytes []byte
- //fmt.Printf("%v", h)
+ //fmt.Printf("%v", header)
if utils.IsSocks5(h) {
if *s.cfg.DisableSocks5 {
(*inConn).Close()
@@ -306,6 +378,25 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
auth = socks.Auth{User: userpassA[0], Password: userpassA[1]}
}
}
+ } else {
+ //ss
+ if *s.cfg.DisableSS {
+ (*inConn).Close()
+ return
+ }
+ (*inConn).SetDeadline(time.Now().Add(time.Second * 5))
+ ssConn := ss.NewConn(*inConn, s.localCipher.Copy())
+ address, err = ss.GetRequest(ssConn)
+ (*inConn).SetDeadline(time.Time{})
+ if err != nil {
+ return
+ }
+ // ensure the host does not contain some illegal characters, NUL may panic on Win32
+ if strings.ContainsRune(address, 0x00) {
+ err = errors.New("invalid domain name")
+ return
+ }
+ *inConn = ssConn
}
if err != nil || address == "" {
s.log.Printf("unknown request from: %s,%s", (*inConn).RemoteAddr(), string(h))
@@ -316,22 +407,20 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
}
//connect to parent
var outConn net.Conn
- outConn, err = s.outPool.Get()
+ selectAddr := (*inConn).RemoteAddr().String()
+ if utils.LBMethod(*s.cfg.LoadBalanceMethod) == lb.SELECT_HASH && *s.cfg.LoadBalanceHashTarget {
+ selectAddr = address
+ }
+ lbAddr = s.lb.Select(selectAddr, *s.cfg.LoadBalanceOnlyHA)
+ //lbAddr = s.lb.Select((*inConn).RemoteAddr().String())
+ outConn, err = s.GetParentConn(lbAddr)
if err != nil {
- s.log.Printf("connect to %s , err:%s", *s.cfg.Parent, err)
+ s.log.Printf("connect to %s , err:%s", lbAddr, err)
utils.CloseConn(inConn)
return
}
- if *s.cfg.ParentCompress {
- outConn = utils.NewCompConn(outConn)
- }
- if *s.cfg.ParentKey != "" {
- outConn = conncrypt.New(outConn, &conncrypt.Config{
- Password: *s.cfg.ParentKey,
- })
- }
- if *s.cfg.ParentAuth != "" || s.IsBasicAuth() {
+ if *s.cfg.ParentAuth != "" || *s.cfg.ParentSSKey != "" || s.IsBasicAuth() {
forwardBytes = utils.RemoveProxyHeaders(forwardBytes)
}
@@ -345,7 +434,10 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
isHTTPS = true
pb.Write([]byte(fmt.Sprintf("CONNECT %s HTTP/1.1\r\n", address)))
}
+ pb.WriteString(fmt.Sprintf("Host: %s\r\n", address))
+ pb.WriteString(fmt.Sprintf("Proxy-Host: %s\r\n", address))
pb.WriteString("Proxy-Connection: Keep-Alive\r\n")
+ pb.WriteString("Connection: Keep-Alive\r\n")
u := ""
if *s.cfg.ParentAuth != "" {
@@ -377,7 +469,7 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
_, err = outConn.Write(pb.Bytes())
outConn.SetDeadline(time.Time{})
if err != nil {
- s.log.Printf("write CONNECT to %s , err:%s", *s.cfg.Parent, err)
+ s.log.Printf("write CONNECT to %s , err:%s", lbAddr, err)
utils.CloseConn(inConn)
utils.CloseConn(&outConn)
return
@@ -389,7 +481,7 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
_, err = outConn.Read(reply)
outConn.SetDeadline(time.Time{})
if err != nil {
- s.log.Printf("read reply from %s , err:%s", *s.cfg.Parent, err)
+ s.log.Printf("read reply from %s , err:%s", lbAddr, err)
utils.CloseConn(inConn)
utils.CloseConn(&outConn)
return
@@ -398,23 +490,24 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
}
} else if *s.cfg.ParentServiceType == "socks" {
s.log.Printf("connect %s", address)
+
//socks client
- var clientConn *socks.ClientConn
- if *s.cfg.ParentAuth != "" {
- a := strings.Split(*s.cfg.ParentAuth, ":")
- if len(a) != 2 {
- err = fmt.Errorf("parent auth data format error")
- return
- }
- clientConn = socks.NewClientConn(&outConn, "tcp", address, time.Millisecond*time.Duration(*s.cfg.Timeout), &socks.Auth{User: a[0], Password: a[1]}, nil)
- } else {
- if !s.IsBasicAuth() && auth.Password != "" && auth.User != "" {
- clientConn = socks.NewClientConn(&outConn, "tcp", address, time.Millisecond*time.Duration(*s.cfg.Timeout), &auth, nil)
- } else {
- clientConn = socks.NewClientConn(&outConn, "tcp", address, time.Millisecond*time.Duration(*s.cfg.Timeout), nil, nil)
- }
+ _, err = s.HandshakeSocksParent(&outConn, "tcp", address, auth, false)
+ if err != nil {
+ s.log.Printf("handshake fail, %s", err)
+ return
}
- if err = clientConn.Handshake(); err != nil {
+
+ } else if *s.cfg.ParentServiceType == "ss" {
+ ra, e := ss.RawAddr(address)
+ if e != nil {
+ err = fmt.Errorf("build ss raw addr fail, err: %s", e)
+ return
+ }
+
+ outConn, err = ss.DialWithRawAddr(&outConn, ra, "", s.parentCipher.Copy())
+ if err != nil {
+ err = fmt.Errorf("dial ss parent fail, err : %s", err)
return
}
}
@@ -424,18 +517,27 @@ func (s *SPS) OutToTCP(inConn *net.Conn) (err error) {
outConn.Write(forwardBytes)
}
+ if s.cfg.RateLimitBytes > 0 {
+ outConn = iolimiter.NewReaderConn(outConn, s.cfg.RateLimitBytes)
+ }
+
//bind
inAddr := (*inConn).RemoteAddr().String()
outAddr := outConn.RemoteAddr().String()
utils.IoBind((*inConn), outConn, func(err interface{}) {
s.log.Printf("conn %s - %s released [%s]", inAddr, outAddr, address)
s.userConns.Remove(inAddr)
+ s.lb.DecreaseConns(lbAddr)
}, s.log)
s.log.Printf("conn %s - %s connected [%s]", inAddr, outAddr, address)
+
+ s.lb.IncreasConns(lbAddr)
+
if c, ok := s.userConns.Get(inAddr); ok {
(*c.(*net.Conn)).Close()
}
s.userConns.Set(inAddr, inConn)
+
return
}
func (s *SPS) InitBasicAuth() (err error) {
@@ -463,6 +565,27 @@ func (s *SPS) InitBasicAuth() (err error) {
}
return
}
+func (s *SPS) InitLB() {
+ configs := lb.BackendsConfig{}
+ for _, addr := range *s.cfg.Parent {
+ _addrInfo := strings.Split(addr, "@")
+ _addr := _addrInfo[0]
+ weight := 1
+ if len(_addrInfo) == 2 {
+ weight, _ = strconv.Atoi(_addrInfo[1])
+ }
+ configs = append(configs, &lb.BackendConfig{
+ Address: _addr,
+ Weight: weight,
+ ActiveAfter: 1,
+ InactiveAfter: 2,
+ Timeout: time.Duration(*s.cfg.LoadBalanceTimeout) * time.Millisecond,
+ RetryTime: time.Duration(*s.cfg.LoadBalanceRetryTime) * time.Millisecond,
+ })
+ }
+ LB := lb.NewGroup(utils.LBMethod(*s.cfg.LoadBalanceMethod), configs, &s.domainResolver, s.log, *s.cfg.Debug)
+ s.lb = &LB
+}
func (s *SPS) IsBasicAuth() bool {
return *s.cfg.AuthFile != "" || len(*s.cfg.Auth) > 0 || *s.cfg.AuthURL != ""
}
@@ -515,3 +638,75 @@ func (s *SPS) Resolve(address string) string {
}
return ip
}
+func (s *SPS) GetParentConn(address string) (conn net.Conn, err error) {
+ if *s.cfg.ParentType == "tls" {
+ var _conn tls.Conn
+ _conn, err = utils.TlsConnectHost(address, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, s.cfg.CaCertBytes)
+ if err == nil {
+ conn = net.Conn(&_conn)
+ }
+ } else if *s.cfg.ParentType == "kcp" {
+ conn, err = utils.ConnectKCPHost(address, s.cfg.KCP)
+ } else {
+ conn, err = utils.ConnectHost(address, *s.cfg.Timeout)
+ }
+ if err == nil {
+ if *s.cfg.ParentCompress {
+ conn = utils.NewCompConn(conn)
+ }
+ if *s.cfg.ParentKey != "" {
+ conn = conncrypt.New(conn, &conncrypt.Config{
+ Password: *s.cfg.ParentKey,
+ })
+ }
+ }
+ return
+}
+func (s *SPS) HandshakeSocksParent(outconn *net.Conn, network, dstAddr string, auth socks.Auth, fromSS bool) (client *socks.ClientConn, err error) {
+ if *s.cfg.ParentAuth != "" {
+ a := strings.Split(*s.cfg.ParentAuth, ":")
+ if len(a) != 2 {
+ err = fmt.Errorf("parent auth data format error")
+ return
+ }
+ client = socks.NewClientConn(outconn, network, dstAddr, time.Millisecond*time.Duration(*s.cfg.Timeout), &socks.Auth{User: a[0], Password: a[1]}, nil)
+ } else {
+ if !fromSS && !s.IsBasicAuth() && auth.Password != "" && auth.User != "" {
+ client = socks.NewClientConn(outconn, network, dstAddr, time.Millisecond*time.Duration(*s.cfg.Timeout), &auth, nil)
+ } else {
+ client = socks.NewClientConn(outconn, network, dstAddr, time.Millisecond*time.Duration(*s.cfg.Timeout), nil, nil)
+ }
+ }
+ err = client.Handshake()
+ return
+}
+func (s *SPS) ParentUDPKey() (key []byte) {
+ switch *s.cfg.ParentType {
+ case "tcp":
+ if *s.cfg.ParentKey != "" {
+ v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.ParentKey)))
+ return []byte(v)[:24]
+ }
+ case "tls":
+ return s.cfg.KeyBytes[:24]
+ case "kcp":
+ v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.KCP.Key)))
+ return []byte(v)[:24]
+ }
+ return
+}
+func (s *SPS) LocalUDPKey() (key []byte) {
+ switch *s.cfg.LocalType {
+ case "tcp":
+ if *s.cfg.LocalKey != "" {
+ v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.LocalKey)))
+ return []byte(v)[:24]
+ }
+ case "tls":
+ return s.cfg.KeyBytes[:24]
+ case "kcp":
+ v := fmt.Sprintf("%x", md5.Sum([]byte(*s.cfg.KCP.Key)))
+ return []byte(v)[:24]
+ }
+ return
+}
diff --git a/services/sps/ssudp.go b/services/sps/ssudp.go
new file mode 100644
index 0000000..9eccf1e
--- /dev/null
+++ b/services/sps/ssudp.go
@@ -0,0 +1,161 @@
+package sps
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "runtime/debug"
+ "time"
+
+ "github.com/snail007/goproxy/utils"
+ goaes "github.com/snail007/goproxy/utils/aes"
+ "github.com/snail007/goproxy/utils/socks"
+)
+
+func (s *SPS) RunSSUDP(addr string) (err error) {
+ a, _ := net.ResolveUDPAddr("udp", addr)
+ listener, err := net.ListenUDP("udp", a)
+ if err != nil {
+ s.log.Printf("ss udp bind error %s", err)
+ return
+ }
+ s.log.Printf("ss udp on %s", listener.LocalAddr())
+ s.udpRelatedPacketConns.Set(addr, listener)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("udp local->out io copy crashed:\n%s\n%s", e, string(debug.Stack()))
+ }
+ }()
+ for {
+ buf := utils.LeakyBuffer.Get()
+ defer utils.LeakyBuffer.Put(buf)
+ n, srcAddr, err := listener.ReadFrom(buf)
+ if err != nil {
+ s.log.Printf("read from client error %s", err)
+ if utils.IsNetClosedErr(err) {
+ return
+ }
+ continue
+ }
+ var (
+ inconnRemoteAddr = srcAddr.String()
+ outUDPConn *net.UDPConn
+ outconn net.Conn
+ outconnLocalAddr string
+ destAddr *net.UDPAddr
+ clean = func(msg, err string) {
+ raddr := ""
+ if outUDPConn != nil {
+ raddr = outUDPConn.RemoteAddr().String()
+ outUDPConn.Close()
+ }
+ if msg != "" {
+ if raddr != "" {
+ s.log.Printf("%s , %s , %s -> %s", msg, err, inconnRemoteAddr, raddr)
+ } else {
+ s.log.Printf("%s , %s , from : %s", msg, err, inconnRemoteAddr)
+ }
+ }
+ s.userConns.Remove(inconnRemoteAddr)
+ if outconn != nil {
+ outconn.Close()
+ }
+ if outconnLocalAddr != "" {
+ s.userConns.Remove(outconnLocalAddr)
+ }
+ }
+ )
+ defer clean("", "")
+
+ raw := new(bytes.Buffer)
+ raw.Write([]byte{0x00, 0x00, 0x00})
+ raw.Write(s.localCipher.Decrypt(buf[:n]))
+ socksPacket := socks.NewPacketUDP()
+ err = socksPacket.Parse(raw.Bytes())
+ raw = nil
+ if err != nil {
+ s.log.Printf("udp parse error %s", err)
+ return
+ }
+
+ if v, ok := s.udpRelatedPacketConns.Get(inconnRemoteAddr); !ok {
+ //socks client
+ lbAddr := s.lb.Select(inconnRemoteAddr, *s.cfg.LoadBalanceOnlyHA)
+ outconn, err := s.GetParentConn(lbAddr)
+ if err != nil {
+ clean("connnect fail", fmt.Sprintf("%s", err))
+ return
+ }
+
+ client, err := s.HandshakeSocksParent(&outconn, "udp", socksPacket.Addr(), socks.Auth{}, true)
+ if err != nil {
+ clean("handshake fail", fmt.Sprintf("%s", err))
+ return
+ }
+
+ outconnLocalAddr = outconn.LocalAddr().String()
+ s.userConns.Set(outconnLocalAddr, &outconn)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ s.log.Printf("udp related parent tcp conn read crashed:\n%s\n%s", e, string(debug.Stack()))
+ }
+ }()
+ buf := make([]byte, 1)
+ outconn.SetReadDeadline(time.Time{})
+ if _, err := outconn.Read(buf); err != nil {
+ clean("udp parent tcp conn disconnected", fmt.Sprintf("%s", err))
+ }
+ }()
+ destAddr, _ = net.ResolveUDPAddr("udp", client.UDPAddr)
+ localZeroAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
+ outUDPConn, err = net.DialUDP("udp", localZeroAddr, destAddr)
+ if err != nil {
+ s.log.Printf("create out udp conn fail , %s , from : %s", err, srcAddr)
+ return
+ }
+ s.udpRelatedPacketConns.Set(srcAddr.String(), outUDPConn)
+ utils.UDPCopy(listener, outUDPConn, srcAddr, time.Second*5, func(data []byte) []byte {
+ //forward to local
+ var v []byte
+ //convert parent data to raw
+ if len(s.udpParentKey) > 0 {
+ v, err = goaes.Decrypt(s.udpParentKey, data)
+ if err != nil {
+ s.log.Printf("udp outconn parse packet fail, %s", err.Error())
+ return []byte{}
+ }
+ } else {
+ v = data
+ }
+ return s.localCipher.Encrypt(v[3:])
+ }, func(err interface{}) {
+ s.udpRelatedPacketConns.Remove(srcAddr.String())
+ if err != nil {
+ s.log.Printf("udp out->local io copy crashed:\n%s\n%s", err, string(debug.Stack()))
+ }
+ })
+ } else {
+ outUDPConn = v.(*net.UDPConn)
+ }
+ //forward to parent
+ //p is raw, now convert it to parent
+ var v []byte
+ if len(s.udpParentKey) > 0 {
+ v, _ = goaes.Encrypt(s.udpParentKey, socksPacket.Bytes())
+ } else {
+ v = socksPacket.Bytes()
+ }
+ _, err = outUDPConn.Write(v)
+ socksPacket = socks.PacketUDP{}
+ if err != nil {
+ if utils.IsNetClosedErr(err) {
+ return
+ }
+ s.log.Printf("send out udp data fail , %s , from : %s", err, srcAddr)
+ }
+ }
+ }()
+ return
+}
diff --git a/services/tcp/tcp.go b/services/tcp/tcp.go
index 1ca2105..4f4c792 100644
--- a/services/tcp/tcp.go
+++ b/services/tcp/tcp.go
@@ -1,55 +1,66 @@
package tcp
import (
- "bufio"
"crypto/tls"
"fmt"
- "io"
logger "log"
"net"
"runtime/debug"
+ "strings"
"time"
"github.com/snail007/goproxy/services"
"github.com/snail007/goproxy/services/kcpcfg"
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/jumper"
+ "github.com/snail007/goproxy/utils/mapx"
"strconv"
)
type TCPArgs struct {
- Parent *string
- CertFile *string
- KeyFile *string
- CertBytes []byte
- KeyBytes []byte
- Local *string
- ParentType *string
- LocalType *string
- Timeout *int
- KCP kcpcfg.KCPConfigArgs
- Jumper *string
+ Parent *string
+ CertFile *string
+ KeyFile *string
+ CertBytes []byte
+ KeyBytes []byte
+ Local *string
+ ParentType *string
+ LocalType *string
+ Timeout *int
+ CheckParentInterval *int
+ KCP kcpcfg.KCPConfigArgs
+ Jumper *string
+}
+type UDPConnItem struct {
+ conn *net.Conn
+ isActive bool
+ touchtime int64
+ srcAddr *net.UDPAddr
+ localAddr *net.UDPAddr
+ udpConn *net.UDPConn
+ connid string
}
-
type TCP struct {
cfg TCPArgs
sc *utils.ServerChannel
isStop bool
- userConns utils.ConcurrentMap
+ userConns mapx.ConcurrentMap
log *logger.Logger
jumper *jumper.Jumper
+ udpConns mapx.ConcurrentMap
}
func NewTCP() services.Service {
return &TCP{
cfg: TCPArgs{},
isStop: false,
- userConns: utils.NewConcurrentMap(),
+ userConns: mapx.NewConcurrentMap(),
+ udpConns: mapx.NewConcurrentMap(),
}
}
func (s *TCP) CheckArgs() (err error) {
- if *s.cfg.Parent == "" {
+ if len(*s.cfg.Parent) == 0 {
err = fmt.Errorf("parent required for %s %s", *s.cfg.LocalType, *s.cfg.Local)
return
}
@@ -79,7 +90,7 @@ func (s *TCP) CheckArgs() (err error) {
return
}
func (s *TCP) InitService() (err error) {
-
+ s.UDPGCDeamon()
return
}
func (s *TCP) StopService() {
@@ -88,8 +99,14 @@ func (s *TCP) StopService() {
if e != nil {
s.log.Printf("stop tcp service crashed,%s", e)
} else {
- s.log.Printf("service tcp stopped")
+ s.log.Printf("service tcp stoped")
}
+ s.cfg = TCPArgs{}
+ s.jumper = nil
+ s.log = nil
+ s.sc = nil
+ s.userConns = nil
+ s = nil
}()
s.isStop = true
if s.sc.Listener != nil && *s.sc.Listener != nil {
@@ -111,7 +128,7 @@ func (s *TCP) Start(args interface{}, log *logger.Logger) (err error) {
if err = s.InitService(); err != nil {
return
}
- s.log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent)
+ s.log.Printf("use %s parent %v", *s.cfg.ParentType, *s.cfg.Parent)
host, port, _ := net.SplitHostPort(*s.cfg.Local)
p, _ := strconv.Atoi(port)
sc := utils.NewServerChannel(host, p, s.log)
@@ -141,6 +158,7 @@ func (s *TCP) callback(inConn net.Conn) {
}
}()
var err error
+ lbAddr := ""
switch *s.cfg.ParentType {
case "kcp":
fallthrough
@@ -149,16 +167,19 @@ func (s *TCP) callback(inConn net.Conn) {
case "tls":
err = s.OutToTCP(&inConn)
case "udp":
- err = s.OutToUDP(&inConn)
+ s.OutToUDP(&inConn)
default:
err = fmt.Errorf("unkown parent type %s", *s.cfg.ParentType)
}
if err != nil {
- s.log.Printf("connect to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err)
+ if !utils.IsNetClosedErr(err) {
+ s.log.Printf("connect to %s parent %s fail, ERR:%s", *s.cfg.ParentType, lbAddr, err)
+ }
utils.CloseConn(&inConn)
}
}
func (s *TCP) OutToTCP(inConn *net.Conn) (err error) {
+
var outConn net.Conn
outConn, err = s.GetParentConn()
if err != nil {
@@ -182,56 +203,136 @@ func (s *TCP) OutToTCP(inConn *net.Conn) (err error) {
return
}
func (s *TCP) OutToUDP(inConn *net.Conn) (err error) {
- s.log.Printf("conn created , remote : %s ", (*inConn).RemoteAddr())
+ var item *UDPConnItem
+ var body []byte
+ srcAddr := ""
+ defer func() {
+ if item != nil {
+ (*(*item).conn).Close()
+ (*item).udpConn.Close()
+ s.udpConns.Remove(srcAddr)
+ (*inConn).Close()
+ }
+ }()
for {
if s.isStop {
- (*inConn).Close()
return
}
- srcAddr, body, err := utils.ReadUDPPacket(bufio.NewReader(*inConn))
- if err == io.EOF || err == io.ErrUnexpectedEOF {
- //s.log.Printf("connection %s released", srcAddr)
- utils.CloseConn(inConn)
- break
- }
- //log.Debugf("udp packet revecived:%s,%v", srcAddr, body)
- dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Parent)
+ var srcAddr string
+ srcAddr, body, err = utils.ReadUDPPacket(*inConn)
if err != nil {
- s.log.Printf("can't resolve address: %s", err)
- utils.CloseConn(inConn)
- break
+ if strings.Contains(err.Error(), "n != int(") {
+ continue
+ }
+ // if !utils.IsNetDeadlineErr(err) && err != io.EOF && !utils.IsNetClosedErr(err) {
+ // s.log.Printf("udp packet revecived from client fail, err: %s", err)
+ // }
+ return
}
- clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
- conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
- if err != nil {
- s.log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
- continue
+ localAddr := *s.cfg.Parent
+ if v, ok := s.udpConns.Get(srcAddr); !ok {
+ _srcAddr, _ := net.ResolveUDPAddr("udp", srcAddr)
+ zeroAddr, _ := net.ResolveUDPAddr("udp", ":")
+ _localAddr, _ := net.ResolveUDPAddr("udp", localAddr)
+ var c *net.UDPConn
+ c, err = net.DialUDP("udp", zeroAddr, _localAddr)
+ if err != nil {
+ s.log.Printf("create local udp conn fail, err : %s", err)
+ (*inConn).Close()
+ return
+ }
+ item = &UDPConnItem{
+ conn: inConn,
+ srcAddr: _srcAddr,
+ localAddr: _localAddr,
+ udpConn: c,
+ }
+ s.udpConns.Set(srcAddr, item)
+ s.UDPRevecive(srcAddr)
+ } else {
+ item = v.(*UDPConnItem)
}
- conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = conn.Write(body)
- if err != nil {
- s.log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
- continue
- }
- //log.Debugf("send udp packet to %s success", dstAddr.String())
- buf := make([]byte, 512)
- len, _, err := conn.ReadFromUDP(buf)
- if err != nil {
- s.log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
- continue
- }
- respBody := buf[0:len]
- //log.Debugf("revecived udp packet from %s , %v", dstAddr.String(), respBody)
- _, err = (*inConn).Write(utils.UDPPacket(srcAddr, respBody))
- if err != nil {
- s.log.Printf("send udp response fail ,ERR:%s", err)
- utils.CloseConn(inConn)
- break
- }
- //s.log.Printf("send udp response success ,from:%s", dstAddr.String())
+ (*item).touchtime = time.Now().Unix()
+ go (*item).udpConn.Write(body)
}
- return
-
+}
+func (s *TCP) UDPRevecive(key string) {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.log.Printf("udp conn %s connected", key)
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ s.log.Printf("[warn] udp conn not exists for %s", key)
+ return
+ }
+ cui := v.(*UDPConnItem)
+ buf := utils.LeakyBuffer.Get()
+ defer func() {
+ utils.LeakyBuffer.Put(buf)
+ (*cui.conn).Close()
+ cui.udpConn.Close()
+ s.udpConns.Remove(key)
+ s.log.Printf("udp conn %s released", key)
+ }()
+ for {
+ n, err := cui.udpConn.Read(buf)
+ if err != nil {
+ if !utils.IsNetClosedErr(err) {
+ s.log.Printf("udp conn read udp packet fail , err: %s ", err)
+ }
+ return
+ }
+ cui.touchtime = time.Now().Unix()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ (*cui.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ _, err = (*cui.conn).Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n]))
+ (*cui.conn).SetWriteDeadline(time.Time{})
+ if err != nil {
+ cui.udpConn.Close()
+ return
+ }
+ }()
+ }
+ }()
+}
+func (s *TCP) UDPGCDeamon() {
+ gctime := int64(30)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ if s.isStop {
+ return
+ }
+ timer := time.NewTicker(time.Second)
+ for {
+ <-timer.C
+ gcKeys := []string{}
+ s.udpConns.IterCb(func(key string, v interface{}) {
+ if time.Now().Unix()-v.(*UDPConnItem).touchtime > gctime {
+ (*(v.(*UDPConnItem).conn)).Close()
+ (v.(*UDPConnItem).udpConn).Close()
+ gcKeys = append(gcKeys, key)
+ s.log.Printf("gc udp conn %s", key)
+ }
+ })
+ for _, k := range gcKeys {
+ s.udpConns.Remove(k)
+ }
+ gcKeys = nil
+ }
+ }()
}
func (s *TCP) GetParentConn() (conn net.Conn, err error) {
if *s.cfg.ParentType == "tls" {
diff --git a/services/tunnel/tunnel_bridge.go b/services/tunnel/tunnel_bridge.go
index f02981c..6dbc21b 100644
--- a/services/tunnel/tunnel_bridge.go
+++ b/services/tunnel/tunnel_bridge.go
@@ -7,10 +7,12 @@ import (
"net"
"os"
"strconv"
+ "strings"
"time"
"github.com/snail007/goproxy/services"
"github.com/snail007/goproxy/utils"
+ "github.com/snail007/goproxy/utils/mapx"
//"github.com/xtaci/smux"
smux "github.com/hashicorp/yamux"
@@ -37,8 +39,8 @@ type ServerConn struct {
}
type TunnelBridge struct {
cfg TunnelBridgeArgs
- serverConns utils.ConcurrentMap
- clientControlConns utils.ConcurrentMap
+ serverConns mapx.ConcurrentMap
+ clientControlConns mapx.ConcurrentMap
isStop bool
log *logger.Logger
}
@@ -46,8 +48,8 @@ type TunnelBridge struct {
func NewTunnelBridge() services.Service {
return &TunnelBridge{
cfg: TunnelBridgeArgs{},
- serverConns: utils.NewConcurrentMap(),
- clientControlConns: utils.NewConcurrentMap(),
+ serverConns: mapx.NewConcurrentMap(),
+ clientControlConns: mapx.NewConcurrentMap(),
isStop: false,
}
}
@@ -69,8 +71,13 @@ func (s *TunnelBridge) StopService() {
if e != nil {
s.log.Printf("stop tbridge service crashed,%s", e)
} else {
- s.log.Printf("service tbridge stopped")
+ s.log.Printf("service tbridge stoped")
}
+ s.cfg = TunnelBridgeArgs{}
+ s.clientControlConns = nil
+ s.log = nil
+ s.serverConns = nil
+ s = nil
}()
s.isStop = true
for _, sess := range s.clientControlConns.Items() {
@@ -123,10 +130,24 @@ func (s *TunnelBridge) callback(inConn net.Conn) {
s.log.Printf("mux server conn accept error,ERR:%s", err)
return
}
-
+ go func() {
+ defer func() {
+ _ = recover()
+ }()
+ timer := time.NewTicker(time.Second * 3)
+ for {
+ <-timer.C
+ if sess.NumStreams() == 0 {
+ sess.Close()
+ timer.Stop()
+ return
+ }
+ }
+ }()
var buf = make([]byte, 1024)
n, _ := inConn.Read(buf)
reader := bytes.NewReader(buf[:n])
+
//reader := bufio.NewReader(inConn)
var connType uint8
@@ -163,7 +184,7 @@ func (s *TunnelBridge) callback(inConn net.Conn) {
(*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3))
_, err := (*item.(*net.Conn)).Write(packet)
(*item.(*net.Conn)).SetWriteDeadline(time.Time{})
- if err != nil {
+ if err != nil && strings.Contains(err.Error(), "stream closed") {
s.log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
time.Sleep(time.Second * 3)
continue
diff --git a/services/tunnel/tunnel_client.go b/services/tunnel/tunnel_client.go
index d83c3f6..9747934 100644
--- a/services/tunnel/tunnel_client.go
+++ b/services/tunnel/tunnel_client.go
@@ -7,20 +7,19 @@ import (
logger "log"
"net"
"os"
+ "runtime/debug"
+ "strings"
"time"
"github.com/snail007/goproxy/services"
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/jumper"
+ "github.com/snail007/goproxy/utils/mapx"
+
//"github.com/xtaci/smux"
smux "github.com/hashicorp/yamux"
)
-const (
- CONN_SERVER_MUX = uint8(6)
- CONN_CLIENT_MUX = uint8(7)
-)
-
type TunnelClientArgs struct {
Parent *string
CertFile *string
@@ -31,24 +30,36 @@ type TunnelClientArgs struct {
Timeout *int
Jumper *string
}
+type ClientUDPConnItem struct {
+ conn *net.Conn
+ isActive bool
+ touchtime int64
+ srcAddr *net.UDPAddr
+ localAddr *net.UDPAddr
+ udpConn *net.UDPConn
+ connid string
+}
type TunnelClient struct {
cfg TunnelClientArgs
- ctrlConn net.Conn
+ ctrlConn *net.Conn
isStop bool
- userConns utils.ConcurrentMap
+ userConns mapx.ConcurrentMap
log *logger.Logger
jumper *jumper.Jumper
+ udpConns mapx.ConcurrentMap
}
func NewTunnelClient() services.Service {
return &TunnelClient{
cfg: TunnelClientArgs{},
- userConns: utils.NewConcurrentMap(),
+ userConns: mapx.NewConcurrentMap(),
isStop: false,
+ udpConns: mapx.NewConcurrentMap(),
}
}
func (s *TunnelClient) InitService() (err error) {
+ s.UDPGCDeamon()
return
}
@@ -84,12 +95,19 @@ func (s *TunnelClient) StopService() {
if e != nil {
s.log.Printf("stop tclient service crashed,%s", e)
} else {
- s.log.Printf("service tclient stopped")
+ s.log.Printf("service tclient stoped")
}
+ s.cfg = TunnelClientArgs{}
+ s.ctrlConn = nil
+ s.jumper = nil
+ s.log = nil
+ s.udpConns = nil
+ s.userConns = nil
+ s = nil
}()
s.isStop = true
if s.ctrlConn != nil {
- s.ctrlConn.Close()
+ (*s.ctrlConn).Close()
}
for _, c := range s.userConns.Items() {
(*c.(*net.Conn)).Close()
@@ -111,38 +129,50 @@ func (s *TunnelClient) Start(args interface{}, log *logger.Logger) (err error) {
return
}
if s.ctrlConn != nil {
- s.ctrlConn.Close()
+ (*s.ctrlConn).Close()
}
-
- s.ctrlConn, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key)
+ var c net.Conn
+ c, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key)
if err != nil {
s.log.Printf("control connection err: %s, retrying...", err)
time.Sleep(time.Second * 3)
- if s.ctrlConn != nil {
- s.ctrlConn.Close()
- }
continue
}
+ s.ctrlConn = &c
for {
if s.isStop {
return
}
var ID, clientLocalAddr, serverID string
- err = utils.ReadPacketData(s.ctrlConn, &ID, &clientLocalAddr, &serverID)
+ err = utils.ReadPacketData(*s.ctrlConn, &ID, &clientLocalAddr, &serverID)
if err != nil {
if s.ctrlConn != nil {
- s.ctrlConn.Close()
+ (*s.ctrlConn).Close()
}
s.log.Printf("read connection signal err: %s, retrying...", err)
break
}
- s.log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr)
+ //s.log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr)
protocol := clientLocalAddr[:3]
localAddr := clientLocalAddr[4:]
if protocol == "udp" {
- go s.ServeUDP(localAddr, ID, serverID)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.ServeUDP(localAddr, ID, serverID)
+ }()
} else {
- go s.ServeConn(localAddr, ID, serverID)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.ServeConn(localAddr, ID, serverID)
+ }()
}
}
}
@@ -183,7 +213,7 @@ func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
}
}
if err == nil {
- c, e := smux.Client(conn, &smux.Config{
+ sess, e := smux.Client(conn, &smux.Config{
AcceptBacklog: 256,
EnableKeepAlive: true,
KeepAliveInterval: 9 * time.Second,
@@ -196,12 +226,26 @@ func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
err = e
return
}
- conn, e = c.OpenStream()
+ conn, e = sess.OpenStream()
if e != nil {
s.log.Printf("mux client conn open stream error,ERR:%s", e)
err = e
return
}
+ go func() {
+ defer func() {
+ _ = recover()
+ }()
+ timer := time.NewTicker(time.Second * 3)
+ for {
+ <-timer.C
+ if sess.NumStreams() == 0 {
+ sess.Close()
+ timer.Stop()
+ return
+ }
+ }
+ }()
}
return
}
@@ -229,62 +273,141 @@ func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) {
}
// s.cm.Add(*s.cfg.Key, ID, &inConn)
s.log.Printf("conn %s created", ID)
-
+ var item *ClientUDPConnItem
+ var body []byte
+ srcAddr := ""
+ defer func() {
+ if item != nil {
+ (*(*item).conn).Close()
+ (*item).udpConn.Close()
+ s.udpConns.Remove(srcAddr)
+ inConn.Close()
+ }
+ }()
for {
if s.isStop {
return
}
- srcAddr, body, err := utils.ReadUDPPacket(inConn)
- if err == io.EOF || err == io.ErrUnexpectedEOF {
- s.log.Printf("connection %s released", ID)
- utils.CloseConn(&inConn)
- break
- } else if err != nil {
- s.log.Printf("udp packet revecived fail, err: %s", err)
- } else {
- //s.log.Printf("udp packet revecived:%s,%v", srcAddr, body)
- go s.processUDPPacket(&inConn, srcAddr, localAddr, body)
+ srcAddr, body, err = utils.ReadUDPPacket(inConn)
+ if err != nil {
+ if strings.Contains(err.Error(), "n != int(") {
+ continue
+ }
+ if !utils.IsNetDeadlineErr(err) && err != io.EOF {
+ s.log.Printf("udp packet revecived from bridge fail, err: %s", err)
+ }
+ return
}
-
+ if v, ok := s.udpConns.Get(srcAddr); !ok {
+ _srcAddr, _ := net.ResolveUDPAddr("udp", srcAddr)
+ zeroAddr, _ := net.ResolveUDPAddr("udp", ":")
+ _localAddr, _ := net.ResolveUDPAddr("udp", localAddr)
+ c, err := net.DialUDP("udp", zeroAddr, _localAddr)
+ if err != nil {
+ s.log.Printf("create local udp conn fail, err : %s", err)
+ inConn.Close()
+ return
+ }
+ item = &ClientUDPConnItem{
+ conn: &inConn,
+ srcAddr: _srcAddr,
+ localAddr: _localAddr,
+ udpConn: c,
+ connid: ID,
+ }
+ s.udpConns.Set(srcAddr, item)
+ s.UDPRevecive(srcAddr, ID)
+ } else {
+ item = v.(*ClientUDPConnItem)
+ }
+ (*item).touchtime = time.Now().Unix()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ (*item).udpConn.Write(body)
+ }()
}
- // }
}
-func (s *TunnelClient) processUDPPacket(inConn *net.Conn, srcAddr, localAddr string, body []byte) {
- dstAddr, err := net.ResolveUDPAddr("udp", localAddr)
- if err != nil {
- s.log.Printf("can't resolve address: %s", err)
- utils.CloseConn(inConn)
- return
- }
- clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
- conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
- if err != nil {
- s.log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
- return
- }
- conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = conn.Write(body)
- if err != nil {
- s.log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
- return
- }
- //s.log.Printf("send udp packet to %s success", dstAddr.String())
- buf := make([]byte, 1024)
- length, _, err := conn.ReadFromUDP(buf)
- if err != nil {
- s.log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
- return
- }
- respBody := buf[0:length]
- //s.log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody)
- bs := utils.UDPPacket(srcAddr, respBody)
- _, err = (*inConn).Write(bs)
- if err != nil {
- s.log.Printf("send udp response fail ,ERR:%s", err)
- utils.CloseConn(inConn)
- return
- }
- //s.log.Printf("send udp response success ,from:%s ,%d ,%v", dstAddr.String(), len(bs), bs)
+func (s *TunnelClient) UDPRevecive(key, ID string) {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.log.Printf("udp conn %s connected", ID)
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID)
+ return
+ }
+ cui := v.(*ClientUDPConnItem)
+ buf := utils.LeakyBuffer.Get()
+ defer func() {
+ utils.LeakyBuffer.Put(buf)
+ (*cui.conn).Close()
+ cui.udpConn.Close()
+ s.udpConns.Remove(key)
+ s.log.Printf("udp conn %s released", ID)
+ }()
+ for {
+ n, err := cui.udpConn.Read(buf)
+ if err != nil {
+ if !utils.IsNetClosedErr(err) {
+ s.log.Printf("udp conn read udp packet fail , err: %s ", err)
+ }
+ return
+ }
+ cui.touchtime = time.Now().Unix()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ (*cui.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ _, err = (*cui.conn).Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n]))
+ (*cui.conn).SetWriteDeadline(time.Time{})
+ if err != nil {
+ cui.udpConn.Close()
+ return
+ }
+ }()
+ }
+ }()
+}
+func (s *TunnelClient) UDPGCDeamon() {
+ gctime := int64(30)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ if s.isStop {
+ return
+ }
+ timer := time.NewTicker(time.Second)
+ for {
+ <-timer.C
+ gcKeys := []string{}
+ s.udpConns.IterCb(func(key string, v interface{}) {
+ if time.Now().Unix()-v.(*ClientUDPConnItem).touchtime > gctime {
+ (*(v.(*ClientUDPConnItem).conn)).Close()
+ (v.(*ClientUDPConnItem).udpConn).Close()
+ gcKeys = append(gcKeys, key)
+ s.log.Printf("gc udp conn %s", v.(*ClientUDPConnItem).connid)
+ }
+ })
+ for _, k := range gcKeys {
+ s.udpConns.Remove(k)
+ }
+ gcKeys = nil
+ }
+ }()
}
func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) {
var inConn, outConn net.Conn
diff --git a/services/tunnel/tunnel_server.go b/services/tunnel/tunnel_server.go
index 98f4904..4cdd220 100644
--- a/services/tunnel/tunnel_server.go
+++ b/services/tunnel/tunnel_server.go
@@ -15,6 +15,7 @@ import (
"github.com/snail007/goproxy/services"
"github.com/snail007/goproxy/utils"
"github.com/snail007/goproxy/utils/jumper"
+ "github.com/snail007/goproxy/utils/mapx"
//"github.com/xtaci/smux"
smux "github.com/hashicorp/yamux"
@@ -35,14 +36,19 @@ type TunnelServerArgs struct {
Mgr *TunnelServerManager
Jumper *string
}
-type UDPItem struct {
- packet *[]byte
- localAddr *net.UDPAddr
- srcAddr *net.UDPAddr
+type TunnelServer struct {
+ cfg TunnelServerArgs
+ sc utils.ServerChannel
+ isStop bool
+ udpConn *net.Conn
+ userConns mapx.ConcurrentMap
+ log *logger.Logger
+ jumper *jumper.Jumper
+ udpConns mapx.ConcurrentMap
}
+
type TunnelServerManager struct {
cfg TunnelServerArgs
- udpChn chan UDPItem
serverID string
servers []*services.Service
log *logger.Logger
@@ -51,7 +57,6 @@ type TunnelServerManager struct {
func NewTunnelServerManager() services.Service {
return &TunnelServerManager{
cfg: TunnelServerArgs{},
- udpChn: make(chan UDPItem, 50000),
serverID: utils.Uniqueid(),
servers: []*services.Service{},
}
@@ -123,6 +128,11 @@ func (s *TunnelServerManager) StopService() {
for _, server := range s.servers {
(*server).Clean()
}
+ s.cfg = TunnelServerArgs{}
+ s.log = nil
+ s.serverID = ""
+ s.servers = nil
+ s = nil
}
func (s *TunnelServerManager) CheckArgs() (err error) {
if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
@@ -136,34 +146,45 @@ func (s *TunnelServerManager) InitService() (err error) {
return
}
-type TunnelServer struct {
- cfg TunnelServerArgs
- udpChn chan UDPItem
- sc utils.ServerChannel
- isStop bool
- udpConn *net.Conn
- userConns utils.ConcurrentMap
- log *logger.Logger
- jumper *jumper.Jumper
-}
-
func NewTunnelServer() services.Service {
return &TunnelServer{
cfg: TunnelServerArgs{},
- udpChn: make(chan UDPItem, 50000),
isStop: false,
- userConns: utils.NewConcurrentMap(),
+ userConns: mapx.NewConcurrentMap(),
+ udpConns: mapx.NewConcurrentMap(),
}
}
+type TunnelUDPPacketItem struct {
+ packet *[]byte
+ localAddr *net.UDPAddr
+ srcAddr *net.UDPAddr
+}
+type TunnelUDPConnItem struct {
+ conn *net.Conn
+ isActive bool
+ touchtime int64
+ srcAddr *net.UDPAddr
+ localAddr *net.UDPAddr
+ connid string
+}
+
func (s *TunnelServer) StopService() {
defer func() {
e := recover()
if e != nil {
s.log.Printf("stop server service crashed,%s", e)
} else {
- s.log.Printf("service server stopped")
+ s.log.Printf("service server stoped")
}
+ s.cfg = TunnelServerArgs{}
+ s.jumper = nil
+ s.log = nil
+ s.sc = utils.ServerChannel{}
+ s.udpConn = nil
+ s.udpConns = nil
+ s.userConns = nil
+ s = nil
}()
s.isStop = true
@@ -181,7 +202,7 @@ func (s *TunnelServer) StopService() {
}
}
func (s *TunnelServer) InitService() (err error) {
- s.UDPConnDeamon()
+ s.UDPGCDeamon()
return
}
func (s *TunnelServer) CheckArgs() (err error) {
@@ -214,12 +235,8 @@ func (s *TunnelServer) Start(args interface{}, log *logger.Logger) (err error) {
p, _ := strconv.Atoi(port)
s.sc = utils.NewServerChannel(host, p, s.log)
if *s.cfg.IsUDP {
- err = s.sc.ListenUDP(func(packet []byte, localAddr, srcAddr *net.UDPAddr) {
- s.udpChn <- UDPItem{
- packet: &packet,
- localAddr: localAddr,
- srcAddr: srcAddr,
- }
+ err = s.sc.ListenUDP(func(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr) {
+ s.UDPSend(packet, localAddr, srcAddr)
})
if err != nil {
return
@@ -346,89 +363,120 @@ func (s *TunnelServer) GetConn() (conn net.Conn, err error) {
}
return
}
-func (s *TunnelServer) UDPConnDeamon() {
+func (s *TunnelServer) UDPGCDeamon() {
+ gctime := int64(30)
go func() {
defer func() {
- if err := recover(); err != nil {
- s.log.Printf("udp conn deamon crashed with err : %s \nstack: %s", err, string(debug.Stack()))
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
}
}()
- var outConn net.Conn
- // var hb utils.HeartbeatReadWriter
- var ID string
- // var cmdChn = make(chan bool, 1000)
- var err error
+ if s.isStop {
+ return
+ }
+ timer := time.NewTicker(time.Second)
for {
- if s.isStop {
- return
- }
- item := <-s.udpChn
- RETRY:
- if s.isStop {
- return
- }
- if outConn == nil {
- for {
- if s.isStop {
- return
- }
- outConn, ID, err = s.GetOutConn(CONN_SERVER)
- if err != nil {
- // cmdChn <- true
- outConn = nil
- utils.CloseConn(&outConn)
- s.log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err)
- time.Sleep(time.Second * 3)
- continue
- } else {
- go func(outConn net.Conn, ID string) {
- if s.udpConn != nil {
- (*s.udpConn).Close()
- }
- s.udpConn = &outConn
- for {
- if s.isStop {
- return
- }
- srcAddrFromConn, body, err := utils.ReadUDPPacket(outConn)
- if err == io.EOF || err == io.ErrUnexpectedEOF {
- s.log.Printf("UDP deamon connection %s exited", ID)
- break
- }
- if err != nil {
- s.log.Printf("parse revecived udp packet fail, err: %s ,%v", err, body)
- continue
- }
- //s.log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn)
- _srcAddr := strings.Split(srcAddrFromConn, ":")
- if len(_srcAddr) != 2 {
- s.log.Printf("parse revecived udp packet fail, addr error : %s", srcAddrFromConn)
- continue
- }
- port, _ := strconv.Atoi(_srcAddr[1])
- dstAddr := &net.UDPAddr{IP: net.ParseIP(_srcAddr[0]), Port: port}
- _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr)
- if err != nil {
- s.log.Printf("udp response to local %s fail,ERR:%s", srcAddrFromConn, err)
- continue
- }
- //s.log.Printf("udp response to local %s success , %v", srcAddrFromConn, body)
- }
- }(outConn, ID)
- break
- }
+ <-timer.C
+ gcKeys := []string{}
+ s.udpConns.IterCb(func(key string, v interface{}) {
+ if time.Now().Unix()-v.(*TunnelUDPConnItem).touchtime > gctime {
+ (*(v.(*TunnelUDPConnItem).conn)).Close()
+ gcKeys = append(gcKeys, key)
+ s.log.Printf("gc udp conn %s", v.(*TunnelUDPConnItem).connid)
}
+ })
+ for _, k := range gcKeys {
+ s.udpConns.Remove(k)
}
- outConn.SetWriteDeadline(time.Now().Add(time.Second))
- _, err = outConn.Write(utils.UDPPacket(item.srcAddr.String(), *item.packet))
- outConn.SetWriteDeadline(time.Time{})
- if err != nil {
- utils.CloseConn(&outConn)
- outConn = nil
- s.log.Printf("write udp packet to %s fail ,flush err:%s ,retrying...", *s.cfg.Parent, err)
- goto RETRY
- }
- //s.log.Printf("write packet %v", *item.packet)
+ gcKeys = nil
+ }
+ }()
+}
+func (s *TunnelServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) {
+ var (
+ uc *TunnelUDPConnItem
+ key = srcAddr.String()
+ ID string
+ err error
+ outconn net.Conn
+ )
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ outconn, ID, err = s.GetOutConn(CONN_SERVER)
+ if err != nil {
+ s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err)
+ return
+ }
+ uc = &TunnelUDPConnItem{
+ conn: &outconn,
+ srcAddr: srcAddr,
+ localAddr: localAddr,
+ connid: ID,
+ }
+ s.udpConns.Set(key, uc)
+ s.UDPRevecive(key, ID)
+ } else {
+ uc = v.(*TunnelUDPConnItem)
+ }
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ (*uc.conn).Close()
+ s.udpConns.Remove(key)
+ s.log.Printf("udp sender crashed with error : %s", e)
+ }
+ }()
+ uc.touchtime = time.Now().Unix()
+ (*uc.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ _, err = (*uc.conn).Write(utils.UDPPacket(srcAddr.String(), data))
+ (*uc.conn).SetWriteDeadline(time.Time{})
+ if err != nil {
+ s.log.Printf("write udp packet to %s fail ,flush err:%s ", *s.cfg.Parent, err)
+ }
+ }()
+}
+func (s *TunnelServer) UDPRevecive(key, ID string) {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.log.Printf("udp conn %s connected", ID)
+ var uc *TunnelUDPConnItem
+ defer func() {
+ if uc != nil {
+ (*uc.conn).Close()
+ }
+ s.udpConns.Remove(key)
+ s.log.Printf("udp conn %s released", ID)
+ }()
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID)
+ return
+ }
+ uc = v.(*TunnelUDPConnItem)
+ for {
+ _, body, err := utils.ReadUDPPacket(*uc.conn)
+ if err != nil {
+ if strings.Contains(err.Error(), "n != int(") {
+ continue
+ }
+ if err != io.EOF {
+ s.log.Printf("udp conn read udp packet fail , err: %s ", err)
+ }
+ return
+ }
+ uc.touchtime = time.Now().Unix()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.sc.UDPListener.WriteToUDP(body, uc.srcAddr)
+ }()
}
}()
}
diff --git a/services/udp/udp.go b/services/udp/udp.go
index cf2d46f..d6b32dc 100644
--- a/services/udp/udp.go
+++ b/services/udp/udp.go
@@ -1,9 +1,8 @@
package udp
import (
- "bufio"
+ "crypto/tls"
"fmt"
- "hash/crc32"
"io"
logger "log"
"net"
@@ -13,8 +12,8 @@ import (
"time"
"github.com/snail007/goproxy/services"
- "github.com/snail007/goproxy/services/kcpcfg"
"github.com/snail007/goproxy/utils"
+ "github.com/snail007/goproxy/utils/mapx"
)
type UDPArgs struct {
@@ -29,28 +28,44 @@ type UDPArgs struct {
CheckParentInterval *int
}
type UDP struct {
- p utils.ConcurrentMap
- outPool utils.OutConn
- cfg UDPArgs
- sc *utils.ServerChannel
- isStop bool
- log *logger.Logger
+ p mapx.ConcurrentMap
+ cfg UDPArgs
+ sc *utils.ServerChannel
+ isStop bool
+ log *logger.Logger
+ outUDPConnCtxMap mapx.ConcurrentMap
+ udpConns mapx.ConcurrentMap
+ dstAddr *net.UDPAddr
+}
+type UDPConnItem struct {
+ conn *net.Conn
+ touchtime int64
+ srcAddr *net.UDPAddr
+ localAddr *net.UDPAddr
+ connid string
+}
+type outUDPConnCtx struct {
+ localAddr *net.UDPAddr
+ srcAddr *net.UDPAddr
+ udpconn *net.UDPConn
+ touchtime int64
}
func NewUDP() services.Service {
return &UDP{
- outPool: utils.OutConn{},
- p: utils.NewConcurrentMap(),
- isStop: false,
+ p: mapx.NewConcurrentMap(),
+ isStop: false,
+ outUDPConnCtxMap: mapx.NewConcurrentMap(),
+ udpConns: mapx.NewConcurrentMap(),
}
}
func (s *UDP) CheckArgs() (err error) {
- if *s.cfg.Parent == "" {
+ if len(*s.cfg.Parent) == 0 {
err = fmt.Errorf("parent required for udp %s", *s.cfg.Local)
return
}
if *s.cfg.ParentType == "" {
- err = fmt.Errorf("parent type unkown,use -T ")
+ err = fmt.Errorf("parent type unkown,use -T ")
return
}
if *s.cfg.ParentType == "tls" {
@@ -59,12 +74,17 @@ func (s *UDP) CheckArgs() (err error) {
return
}
}
+
+ s.dstAddr, err = net.ResolveUDPAddr("udp", *s.cfg.Parent)
+ if err != nil {
+ s.log.Printf("resolve udp addr %s fail fail,ERR:%s", *s.cfg.Parent, err)
+ return
+ }
return
}
func (s *UDP) InitService() (err error) {
- if *s.cfg.ParentType != "udp" {
- s.InitOutConnPool()
- }
+ s.OutToUDPGCDeamon()
+ s.UDPGCDeamon()
return
}
func (s *UDP) StopService() {
@@ -73,8 +93,13 @@ func (s *UDP) StopService() {
if e != nil {
s.log.Printf("stop udp service crashed,%s", e)
} else {
- s.log.Printf("service udp stopped")
+ s.log.Printf("service udp stoped")
}
+ s.cfg = UDPArgs{}
+ s.log = nil
+ s.p = nil
+ s.sc = nil
+ s = nil
}()
s.isStop = true
if s.sc.Listener != nil && *s.sc.Listener != nil {
@@ -109,32 +134,28 @@ func (s *UDP) Start(args interface{}, log *logger.Logger) (err error) {
func (s *UDP) Clean() {
s.StopService()
}
-func (s *UDP) callback(packet []byte, localAddr, srcAddr *net.UDPAddr) {
+func (s *UDP) callback(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr) {
defer func() {
if err := recover(); err != nil {
s.log.Printf("udp conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack()))
}
}()
- var err error
switch *s.cfg.ParentType {
case "tcp":
fallthrough
case "tls":
- err = s.OutToTCP(packet, localAddr, srcAddr)
+ s.OutToTCP(packet, localAddr, srcAddr)
case "udp":
- err = s.OutToUDP(packet, localAddr, srcAddr)
+ s.OutToUDP(packet, localAddr, srcAddr)
default:
- err = fmt.Errorf("unkown parent type %s", *s.cfg.ParentType)
- }
- if err != nil {
- s.log.Printf("connect to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err)
+ s.log.Printf("unkown parent type %s", *s.cfg.ParentType)
}
}
func (s *UDP) GetConn(connKey string) (conn net.Conn, isNew bool, err error) {
isNew = !s.p.Has(connKey)
var _conn interface{}
if isNew {
- _conn, err = s.outPool.Get()
+ _conn, err = s.GetParentConn()
if err != nil {
return nil, false, err
}
@@ -145,117 +166,231 @@ func (s *UDP) GetConn(connKey string) (conn net.Conn, isNew bool, err error) {
conn = _conn.(net.Conn)
return
}
-func (s *UDP) OutToTCP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
- numLocal := crc32.ChecksumIEEE([]byte(localAddr.String()))
- numSrc := crc32.ChecksumIEEE([]byte(srcAddr.String()))
- mod := uint32(10)
- if mod == 0 {
- mod = 10
- }
- connKey := uint64((numLocal/10)*10 + numSrc%mod)
- conn, isNew, err := s.GetConn(fmt.Sprintf("%d", connKey))
- if err != nil {
- s.log.Printf("upd get conn to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err)
- return
- }
- if isNew {
- go func() {
- defer func() {
- if err := recover(); err != nil {
- s.log.Printf("udp conn handler out to tcp crashed with err : %s \nstack: %s", err, string(debug.Stack()))
- }
- }()
- s.log.Printf("conn %d created , local: %s", connKey, srcAddr.String())
- for {
- if s.isStop {
- conn.Close()
- return
- }
- srcAddrFromConn, body, err := utils.ReadUDPPacket(bufio.NewReader(conn))
- if err == io.EOF || err == io.ErrUnexpectedEOF {
- //s.log.Printf("connection %d released", connKey)
- s.p.Remove(fmt.Sprintf("%d", connKey))
- break
- }
- if err != nil {
- s.log.Printf("parse revecived udp packet fail, err: %s", err)
- continue
- }
- //s.log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn)
- _srcAddr := strings.Split(srcAddrFromConn, ":")
- if len(_srcAddr) != 2 {
- s.log.Printf("parse revecived udp packet fail, addr error : %s", srcAddrFromConn)
- continue
- }
- port, _ := strconv.Atoi(_srcAddr[1])
- dstAddr := &net.UDPAddr{IP: net.ParseIP(_srcAddr[0]), Port: port}
- _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr)
- if err != nil {
- s.log.Printf("udp response to local %s fail,ERR:%s", srcAddr, err)
- continue
- }
- //s.log.Printf("udp response to local %s success", srcAddr)
+func (s *UDP) OutToTCP(data []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
+ s.UDPSend(data, localAddr, srcAddr)
+ return
+}
+func (s *UDP) OutToUDPGCDeamon() {
+ gctime := int64(30)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
}
}()
+ if s.isStop {
+ return
+ }
+ timer := time.NewTicker(time.Second)
+ for {
+ <-timer.C
+ gcKeys := []string{}
+ s.outUDPConnCtxMap.IterCb(func(key string, v interface{}) {
+ if time.Now().Unix()-v.(*outUDPConnCtx).touchtime > gctime {
+ (*(v.(*outUDPConnCtx).udpconn)).Close()
+ gcKeys = append(gcKeys, key)
+ s.log.Printf("gc udp conn %s <--> %s", (*v.(*outUDPConnCtx)).srcAddr, (*v.(*outUDPConnCtx)).localAddr)
+ }
+ })
+ for _, k := range gcKeys {
+ s.outUDPConnCtxMap.Remove(k)
+ }
+ gcKeys = nil
+ }
+ }()
+}
+func (s *UDP) OutToUDP(packet []byte, localAddr, srcAddr *net.UDPAddr) {
+ var ouc *outUDPConnCtx
+ if v, ok := s.outUDPConnCtxMap.Get(srcAddr.String()); !ok {
+ clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
+ conn, err := net.DialUDP("udp", clientSrcAddr, s.dstAddr)
+ if err != nil {
+ s.log.Printf("connect to udp %s fail,ERR:%s", s.dstAddr.String(), err)
+
+ }
+ ouc = &outUDPConnCtx{
+ localAddr: localAddr,
+ srcAddr: srcAddr,
+ udpconn: conn,
+ }
+ s.outUDPConnCtxMap.Set(srcAddr.String(), ouc)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.log.Printf("udp conn %s <--> %s connected", srcAddr.String(), localAddr.String())
+ buf := utils.LeakyBuffer.Get()
+ defer func() {
+ utils.LeakyBuffer.Put(buf)
+ s.outUDPConnCtxMap.Remove(srcAddr.String())
+ s.log.Printf("udp conn %s <--> %s released", srcAddr.String(), localAddr.String())
+ }()
+ for {
+ n, err := ouc.udpconn.Read(buf)
+ if err != nil {
+ if !utils.IsNetClosedErr(err) {
+ s.log.Printf("udp conn read udp packet fail , err: %s ", err)
+ }
+ return
+ }
+ ouc.touchtime = time.Now().Unix()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ (*(s.sc).UDPListener).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ _, err = (*(s.sc).UDPListener).WriteTo(buf[:n], srcAddr)
+ (*(s.sc).UDPListener).SetWriteDeadline(time.Time{})
+ }()
+ }
+ }()
+ } else {
+ ouc = v.(*outUDPConnCtx)
}
- //s.log.Printf("select conn %d , local: %s", connKey, srcAddr.String())
- writer := bufio.NewWriter(conn)
- //fmt.Println(conn, writer)
- writer.Write(utils.UDPPacket(srcAddr.String(), packet))
- err = writer.Flush()
- if err != nil {
- s.log.Printf("write udp packet to %s fail ,flush err:%s", *s.cfg.Parent, err)
- return
- }
- //s.log.Printf("write packet %v", packet)
+ go func() {
+ ouc.touchtime = time.Now().Unix()
+ ouc.udpconn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ ouc.udpconn.Write(packet)
+ ouc.udpconn.SetWriteDeadline(time.Time{})
+ }()
return
}
-func (s *UDP) OutToUDP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
- //s.log.Printf("udp packet revecived:%s,%v", srcAddr, packet)
- dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Parent)
- if err != nil {
- s.log.Printf("resolve udp addr %s fail fail,ERR:%s", dstAddr.String(), err)
- return
+func (s *UDP) GetParentConn() (conn net.Conn, err error) {
+ if *s.cfg.ParentType == "tls" {
+ var _conn tls.Conn
+ _conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
+ if err == nil {
+ conn = net.Conn(&_conn)
+ }
+ } else {
+ conn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout)
}
- clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
- conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
- if err != nil {
- s.log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
- return
- }
- conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
- _, err = conn.Write(packet)
- if err != nil {
- s.log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
- return
- }
- //s.log.Printf("send udp packet to %s success", dstAddr.String())
- buf := make([]byte, 512)
- len, _, err := conn.ReadFromUDP(buf)
- if err != nil {
- s.log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
- return
- }
- //s.log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody)
- _, err = s.sc.UDPListener.WriteToUDP(buf[0:len], srcAddr)
- if err != nil {
- s.log.Printf("send udp response to cluster fail ,ERR:%s", err)
- return
- }
- //s.log.Printf("send udp response to cluster success ,from:%s", dstAddr.String())
return
}
-func (s *UDP) InitOutConnPool() {
- if *s.cfg.ParentType == "tls" || *s.cfg.ParentType == "tcp" {
- //dur int, isTLS bool, certBytes, keyBytes []byte,
- //parent string, timeout int, InitialCap int, MaxCap int
- s.outPool = utils.NewOutConn(
- *s.cfg.CheckParentInterval,
- *s.cfg.ParentType,
- kcpcfg.KCPConfigArgs{},
- s.cfg.CertBytes, s.cfg.KeyBytes, nil,
- *s.cfg.Parent,
- *s.cfg.Timeout,
- )
- }
+func (s *UDP) UDPGCDeamon() {
+ gctime := int64(30)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ if s.isStop {
+ return
+ }
+ timer := time.NewTicker(time.Second)
+ for {
+ <-timer.C
+ gcKeys := []string{}
+ s.udpConns.IterCb(func(key string, v interface{}) {
+ if time.Now().Unix()-v.(*UDPConnItem).touchtime > gctime {
+ (*(v.(*UDPConnItem).conn)).Close()
+ gcKeys = append(gcKeys, key)
+ s.log.Printf("gc udp conn %s", v.(*UDPConnItem).connid)
+ }
+ })
+ for _, k := range gcKeys {
+ s.udpConns.Remove(k)
+ }
+ gcKeys = nil
+ }
+ }()
+}
+func (s *UDP) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) {
+ var (
+ uc *UDPConnItem
+ key = srcAddr.String()
+ err error
+ outconn net.Conn
+ )
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ for {
+ outconn, err = s.GetParentConn()
+ if err != nil && strings.Contains(err.Error(), "can not connect at same time") {
+ time.Sleep(time.Millisecond * 500)
+ continue
+ } else {
+ break
+ }
+ }
+ if err != nil {
+ s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err)
+ return
+ }
+ uc = &UDPConnItem{
+ conn: &outconn,
+ srcAddr: srcAddr,
+ localAddr: localAddr,
+ }
+ s.udpConns.Set(key, uc)
+ s.UDPRevecive(key)
+ } else {
+ uc = v.(*UDPConnItem)
+ }
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ (*uc.conn).Close()
+ s.udpConns.Remove(key)
+ s.log.Printf("udp sender crashed with error : %s", e)
+ }
+ }()
+ uc.touchtime = time.Now().Unix()
+ (*uc.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
+ _, err = (*uc.conn).Write(utils.UDPPacket(fmt.Sprintf("%s", srcAddr.String()), data))
+ (*uc.conn).SetWriteDeadline(time.Time{})
+ if err != nil {
+ s.log.Printf("write udp packet to %s fail ,flush err:%s ", *s.cfg.Parent, err)
+ }
+ }()
+}
+func (s *UDP) UDPRevecive(key string) {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.log.Printf("udp conn %s connected", key)
+ var uc *UDPConnItem
+ defer func() {
+ if uc != nil {
+ (*uc.conn).Close()
+ }
+ s.udpConns.Remove(key)
+ s.log.Printf("udp conn %s released", key)
+ }()
+ v, ok := s.udpConns.Get(key)
+ if !ok {
+ s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key)
+ return
+ }
+ uc = v.(*UDPConnItem)
+ for {
+ _, body, err := utils.ReadUDPPacket(*uc.conn)
+ if err != nil {
+ if strings.Contains(err.Error(), "n != int(") {
+ continue
+ }
+ if err != io.EOF && !utils.IsNetClosedErr(err) {
+ s.log.Printf("udp conn read udp packet fail , err: %s ", err)
+ }
+ return
+ }
+ uc.touchtime = time.Now().Unix()
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ s.sc.UDPListener.WriteToUDP(body, uc.srcAddr)
+ }()
+ }
+ }()
}
diff --git a/utils/conncrypt/conncrypt.go b/utils/conncrypt/conncrypt.go
index 82641e8..f3115e7 100644
--- a/utils/conncrypt/conncrypt.go
+++ b/utils/conncrypt/conncrypt.go
@@ -12,8 +12,8 @@ import (
)
//Confg defaults
-const DefaultIterations = 1024
-const DefaultKeySize = 24 //256bits
+const DefaultIterations = 2048
+const DefaultKeySize = 32 //256bits
var DefaultHashFunc = sha256.New
var DefaultSalt = []byte(`
(;QUHj.BQ?RXzYSO]ifkXp/G!kFmWyXyEV6Nt!d|@bo+N$L9+ EB:
+ return fmt.Sprintf("%.1f EB", b.EBytes())
+ case b > PB:
+ return fmt.Sprintf("%.1f PB", b.PBytes())
+ case b > TB:
+ return fmt.Sprintf("%.1f TB", b.TBytes())
+ case b > GB:
+ return fmt.Sprintf("%.1f GB", b.GBytes())
+ case b > MB:
+ return fmt.Sprintf("%.1f MB", b.MBytes())
+ case b > KB:
+ return fmt.Sprintf("%.1f KB", b.KBytes())
+ default:
+ return fmt.Sprintf("%d B", b)
+ }
+}
+
+func (b ByteSize) MarshalText() ([]byte, error) {
+ return []byte(b.String()), nil
+}
+
+func (b *ByteSize) UnmarshalText(t []byte) error {
+ var val uint64
+ var unit string
+
+ // copy for error message
+ t0 := t
+
+ var c byte
+ var i int
+
+ParseLoop:
+ for i < len(t) {
+ c = t[i]
+ switch {
+ case '0' <= c && c <= '9':
+ if val > cutoff {
+ goto Overflow
+ }
+
+ c = c - '0'
+ val *= 10
+
+ if val > val+uint64(c) {
+ // val+v overflows
+ goto Overflow
+ }
+ val += uint64(c)
+ i++
+
+ default:
+ if i == 0 {
+ goto SyntaxError
+ }
+ break ParseLoop
+ }
+ }
+
+ unit = strings.TrimSpace(string(t[i:]))
+ switch unit {
+ case "Kb", "Mb", "Gb", "Tb", "Pb", "Eb":
+ goto BitsError
+ }
+ unit = strings.ToLower(unit)
+ switch unit {
+ case "", "b", "byte":
+ // do nothing - already in bytes
+
+ case "k", "kb", "kilo", "kilobyte", "kilobytes":
+ if val > maxUint64/uint64(KB) {
+ goto Overflow
+ }
+ val *= uint64(KB)
+
+ case "m", "mb", "mega", "megabyte", "megabytes":
+ if val > maxUint64/uint64(MB) {
+ goto Overflow
+ }
+ val *= uint64(MB)
+
+ case "g", "gb", "giga", "gigabyte", "gigabytes":
+ if val > maxUint64/uint64(GB) {
+ goto Overflow
+ }
+ val *= uint64(GB)
+
+ case "t", "tb", "tera", "terabyte", "terabytes":
+ if val > maxUint64/uint64(TB) {
+ goto Overflow
+ }
+ val *= uint64(TB)
+
+ case "p", "pb", "peta", "petabyte", "petabytes":
+ if val > maxUint64/uint64(PB) {
+ goto Overflow
+ }
+ val *= uint64(PB)
+
+ case "E", "EB", "e", "eb", "eB":
+ if val > maxUint64/uint64(EB) {
+ goto Overflow
+ }
+ val *= uint64(EB)
+
+ default:
+ goto SyntaxError
+ }
+
+ *b = ByteSize(val)
+ return nil
+
+Overflow:
+ *b = ByteSize(maxUint64)
+ return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrRange}
+
+SyntaxError:
+ *b = 0
+ return &strconv.NumError{fnUnmarshalText, string(t0), strconv.ErrSyntax}
+
+BitsError:
+ *b = 0
+ return &strconv.NumError{fnUnmarshalText, string(t0), ErrBits}
+}
diff --git a/utils/dnsx/resolver.go b/utils/dnsx/resolver.go
new file mode 100644
index 0000000..71559f4
--- /dev/null
+++ b/utils/dnsx/resolver.go
@@ -0,0 +1,114 @@
+package dnsx
+
+import (
+ "fmt"
+ logger "log"
+ "net"
+ "strings"
+ "time"
+
+ "github.com/snail007/goproxy/utils/mapx"
+ dns "github.com/miekg/dns"
+)
+
+type DomainResolver struct {
+ ttl int
+ dnsAddrress string
+ data mapx.ConcurrentMap
+ log *logger.Logger
+}
+type DomainResolverItem struct {
+ ip string
+ domain string
+ expiredAt int64
+}
+
+func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver {
+ return DomainResolver{
+ ttl: ttl,
+ dnsAddrress: dnsAddrress,
+ data: mapx.NewConcurrentMap(),
+ log: log,
+ }
+}
+func (a *DomainResolver) DnsAddress() (address string) {
+ address = a.dnsAddrress
+ return
+}
+func (a *DomainResolver) MustResolve(address string) (ip string) {
+ ip, _ = a.Resolve(address)
+ return
+}
+func (a *DomainResolver) Resolve(address string) (ip string, err error) {
+ domain := address
+ port := ""
+ fromCache := "false"
+ defer func() {
+ if port != "" {
+ ip = net.JoinHostPort(ip, port)
+ }
+ a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache)
+ //a.PrintData()
+ }()
+ if strings.Contains(domain, ":") {
+ domain, port, err = net.SplitHostPort(domain)
+ if err != nil {
+ return
+ }
+ }
+ if net.ParseIP(domain) != nil {
+ ip = domain
+ fromCache = "ip ignore"
+ return
+ }
+ item, ok := a.data.Get(domain)
+ if ok {
+ //log.Println("find ", domain)
+ if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
+ ip = (*item.(*DomainResolverItem)).ip
+ fromCache = "true"
+ //log.Println("from cache ", domain)
+ return
+ }
+ } else {
+ item = &DomainResolverItem{
+ domain: domain,
+ }
+
+ }
+ c := new(dns.Client)
+ c.DialTimeout = time.Millisecond * 5000
+ c.ReadTimeout = time.Millisecond * 5000
+ c.WriteTimeout = time.Millisecond * 5000
+ m := new(dns.Msg)
+ m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
+ m.RecursionDesired = true
+ r, _, err := c.Exchange(m, a.dnsAddrress)
+ if r == nil {
+ return
+ }
+ if r.Rcode != dns.RcodeSuccess {
+ err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
+ return
+ }
+ for _, answer := range r.Answer {
+ if answer.Header().Rrtype == dns.TypeA {
+ info := strings.Fields(answer.String())
+ if len(info) >= 5 {
+ ip = info[4]
+ _item := item.(*DomainResolverItem)
+ (*_item).expiredAt = time.Now().Unix() + int64(a.ttl)
+ (*_item).ip = ip
+ a.data.Set(domain, item)
+ return
+ }
+ }
+ }
+ return
+}
+func (a *DomainResolver) PrintData() {
+ for k, item := range a.data.Items() {
+ d := item.(*DomainResolverItem)
+ a.log.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt)
+ }
+}
diff --git a/utils/functions.go b/utils/functions.go
index 9cca81a..a00d334 100755
--- a/utils/functions.go
+++ b/utils/functions.go
@@ -3,10 +3,13 @@ package utils
import (
"bufio"
"bytes"
+ "context"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
+ "encoding/base64"
"encoding/binary"
+ "encoding/hex"
"encoding/pem"
"errors"
"fmt"
@@ -17,14 +20,16 @@ import (
"net"
"net/http"
"os"
+ "os/exec"
+ "strings"
"github.com/snail007/goproxy/services/kcpcfg"
+ "github.com/snail007/goproxy/utils/lb"
"golang.org/x/crypto/pbkdf2"
- "context"
"strconv"
- "strings"
+
"time"
"github.com/snail007/goproxy/utils/id"
@@ -33,6 +38,12 @@ import (
)
func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) {
+ ioBind(dst, src, fn, log, true)
+}
+func IoBindNoClose(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) {
+ ioBind(dst, src, fn, log, false)
+}
+func ioBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger, close bool) {
go func() {
defer func() {
if err := recover(); err != nil {
@@ -68,20 +79,41 @@ func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interfac
case err = <-e2:
//log.Printf("e2")
}
- src.Close()
- dst.Close()
+ func() {
+ defer func() {
+ _ = recover()
+ }()
+ if close {
+ src.Close()
+ }
+ }()
+ func() {
+ defer func() {
+ _ = recover()
+ }()
+ if close {
+ dst.Close()
+ }
+ }()
if fn != nil {
fn(err)
}
}()
}
func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) {
+ defer func() {
+ if e := recover(); e != nil {
+ }
+ }()
buf := LeakyBuffer.Get()
defer LeakyBuffer.Put(buf)
n := 0
for {
n, err = src.Read(buf)
if n > 0 {
+ if n > len(buf) {
+ n = len(buf)
+ }
if _, e := dst.Write(buf[0:n]); e != nil {
return e
}
@@ -122,6 +154,7 @@ func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Con
caBytes := certBytes
if caCertBytes != nil {
caBytes = caCertBytes
+
}
ok := serverCertPool.AppendCertsFromPEM(caBytes)
if !ok {
@@ -212,6 +245,98 @@ func CloseConn(conn *net.Conn) {
(*conn).Close()
}
}
+func Keygen() (err error) {
+ CList := []string{"AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BL", "BM", "BN", "BO", "BR", "BS", "BW", "BY", "BZ", "CA", "CF", "CG", "CH", "CK", "CL", "CM", "CN", "CO", "CR", "CS", "CU", "CY", "CZ", "DE", "DJ", "DK", "DO", "DZ", "EC", "EE", "EG", "ES", "ET", "FI", "FJ", "FR", "GA", "GB", "GD", "GE", "GF", "GH", "GI", "GM", "GN", "GR", "GT", "GU", "GY", "HK", "HN", "HT", "HU", "ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG", "KH", "KP", "KR", "KT", "KW", "KZ", "LA", "LB", "LC", "LI", "LK", "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "MG", "ML", "MM", "MN", "MO", "MS", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE", "NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PF", "PG", "PH", "PK", "PL", "PR", "PT", "PY", "QA", "RO", "RU", "SA", "SB", "SC", "SD", "SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ", "TD", "TG", "TH", "TJ", "TM", "TN", "TO", "TR", "TT", "TW", "TZ", "UA", "UG", "US", "UY", "UZ", "VC", "VE", "VN", "YE", "YU", "ZA", "ZM", "ZR", "ZW"}
+ domainSubfixList := []string{".com", ".edu", ".gov", ".int", ".mil", ".net", ".org", ".biz", ".info", ".pro", ".name", ".museum", ".coop", ".aero", ".xxx", ".idv", ".ac", ".ad", ".ae", ".af", ".ag", ".ai", ".al", ".am", ".an", ".ao", ".aq", ".ar", ".as", ".at", ".au", ".aw", ".az", ".ba", ".bb", ".bd", ".be", ".bf", ".bg", ".bh", ".bi", ".bj", ".bm", ".bn", ".bo", ".br", ".bs", ".bt", ".bv", ".bw", ".by", ".bz", ".ca", ".cc", ".cd", ".cf", ".cg", ".ch", ".ci", ".ck", ".cl", ".cm", ".cn", ".co", ".cr", ".cu", ".cv", ".cx", ".cy", ".cz", ".de", ".dj", ".dk", ".dm", ".do", ".dz", ".ec", ".ee", ".eg", ".eh", ".er", ".es", ".et", ".eu", ".fi", ".fj", ".fk", ".fm", ".fo", ".fr", ".ga", ".gd", ".ge", ".gf", ".gg", ".gh", ".gi", ".gl", ".gm", ".gn", ".gp", ".gq", ".gr", ".gs", ".gt", ".gu", ".gw", ".gy", ".hk", ".hm", ".hn", ".hr", ".ht", ".hu", ".id", ".ie", ".il", ".im", ".in", ".io", ".iq", ".ir", ".is", ".it", ".je", ".jm", ".jo", ".jp", ".ke", ".kg", ".kh", ".ki", ".km", ".kn", ".kp", ".kr", ".kw", ".ky", ".kz", ".la", ".lb", ".lc", ".li", ".lk", ".lr", ".ls", ".lt", ".lu", ".lv", ".ly", ".ma", ".mc", ".md", ".mg", ".mh", ".mk", ".ml", ".mm", ".mn", ".mo", ".mp", ".mq", ".mr", ".ms", ".mt", ".mu", ".mv", ".mw", ".mx", ".my", ".mz", ".na", ".nc", ".ne", ".nf", ".ng", ".ni", ".nl", ".no", ".np", ".nr", ".nu", ".nz", ".om", ".pa", ".pe", ".pf", ".pg", ".ph", ".pk", ".pl", ".pm", ".pn", ".pr", ".ps", ".pt", ".pw", ".py", ".qa", ".re", ".ro", ".ru", ".rw", ".sa", ".sb", ".sc", ".sd", ".se", ".sg", ".sh", ".si", ".sj", ".sk", ".sl", ".sm", ".sn", ".so", ".sr", ".st", ".sv", ".sy", ".sz", ".tc", ".td", ".tf", ".tg", ".th", ".tj", ".tk", ".tl", ".tm", ".tn", ".to", ".tp", ".tr", ".tt", ".tv", ".tw", ".tz", ".ua", ".ug", ".uk", ".um", ".us", ".uy", ".uz", ".va", ".vc", ".ve", ".vg", ".vi", ".vn", ".vu", ".wf", ".ws", ".ye", ".yt", ".yu", ".yr", ".za", ".zm", ".zw"}
+ C := CList[int(RandInt(4))%len(CList)]
+ ST := RandString(int(RandInt(4) % 10))
+ O := RandString(int(RandInt(4) % 10))
+ CN := strings.ToLower(RandString(int(RandInt(4)%10)) + domainSubfixList[int(RandInt(4))%len(domainSubfixList)])
+ //log.Printf("C: %s, ST: %s, O: %s, CN: %s", C, ST, O, CN)
+ var out []byte
+ if len(os.Args) == 3 && os.Args[2] == "ca" {
+ cmd := exec.Command("sh", "-c", "openssl genrsa -out ca.key 2048")
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ logger.Printf("err:%s", err)
+ return
+ }
+ fmt.Println(string(out))
+
+ cmdStr := fmt.Sprintf("openssl req -new -key ca.key -x509 -days 36500 -out ca.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, "*."+CN)
+ cmd = exec.Command("sh", "-c", cmdStr)
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ logger.Printf("err:%s", err)
+ return
+ }
+ fmt.Println(string(out))
+ } else if len(os.Args) == 5 && os.Args[2] == "ca" && os.Args[3] != "" && os.Args[4] != "" {
+ certBytes, _ := ioutil.ReadFile("ca.crt")
+ block, _ := pem.Decode(certBytes)
+ if block == nil || certBytes == nil {
+ panic("failed to parse ca certificate PEM")
+ }
+ x509Cert, _ := x509.ParseCertificate(block.Bytes)
+ if x509Cert == nil {
+ panic("failed to parse block")
+ }
+ name := os.Args[3]
+ days := os.Args[4]
+ cmd := exec.Command("sh", "-c", "openssl genrsa -out "+name+".key 2048")
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ logger.Printf("err:%s", err)
+ return
+ }
+ fmt.Println(string(out))
+
+ cmdStr := fmt.Sprintf("openssl req -new -key %s.key -out %s.csr -subj /C=%s/ST=%s/O=%s/CN=%s", name, name, C, ST, O, CN)
+ fmt.Printf("%s", cmdStr)
+ cmd = exec.Command("sh", "-c", cmdStr)
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ logger.Printf("err:%s", err)
+ return
+ }
+ fmt.Println(string(out))
+
+ cmdStr = fmt.Sprintf("openssl x509 -req -days %s -in %s.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out %s.crt", days, name, name)
+ fmt.Printf("%s", cmdStr)
+ cmd = exec.Command("sh", "-c", cmdStr)
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ logger.Printf("err:%s", err)
+ return
+ }
+
+ fmt.Println(string(out))
+ } else if len(os.Args) == 3 && os.Args[2] == "usage" {
+ fmt.Println(`proxy keygen //generate proxy.crt and proxy.key
+proxy keygen ca //generate ca.crt and ca.key
+proxy keygen ca client0 30 //generate client0.crt client0.key and use ca.crt sign it with 30 days
+ `)
+ } else if len(os.Args) == 2 {
+ cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ logger.Printf("err:%s", err)
+ return
+ }
+ fmt.Println(string(out))
+
+ cmdStr := fmt.Sprintf("openssl req -new -key proxy.key -x509 -days 36500 -out proxy.crt -subj /C=%s/ST=%s/O=%s/CN=%s", C, ST, O, CN)
+ cmd = exec.Command("sh", "-c", cmdStr)
+ out, err = cmd.CombinedOutput()
+ if err != nil {
+ logger.Printf("err:%s", err)
+ return
+ }
+ fmt.Println(string(out))
+ }
+
+ return
+}
var allInterfaceAddrCache []net.IP
@@ -310,10 +435,10 @@ func ReadUDPPacket(_reader io.Reader) (srcAddr string, packet []byte, err error)
return
}
func Uniqueid() string {
- return xid.New().String()
- // var src = rand.NewSource(time.Now().UnixNano())
- // s := fmt.Sprintf("%d", src.Int63())
- // return s[len(s)-5:len(s)-1] + fmt.Sprintf("%d", uint64(time.Now().UnixNano()))[8:]
+ str := fmt.Sprintf("%d%s", time.Now().UnixNano(), xid.New().String())
+ hash := sha1.New()
+ hash.Write([]byte(str))
+ return hex.EncodeToString(hash.Sum(nil))
}
func RandString(strlen int) string {
codes := "QWERTYUIOPLKJHGFDSAZXCVBNMabcdefghijklmnopqrstuvwxyz0123456789"
@@ -338,15 +463,15 @@ func RandInt(strLen int) int64 {
i, _ := strconv.ParseInt(string(data), 10, 64)
return i
}
-func ReadData(r io.Reader) (data string, err error) {
- var len uint16
+func ReadBytes(r io.Reader) (data []byte, err error) {
+ var len uint64
err = binary.Read(r, binary.LittleEndian, &len)
if err != nil {
return
}
var n int
- _data := make([]byte, len)
- n, err = r.Read(_data)
+ data = make([]byte, len)
+ n, err = r.Read(data)
if err != nil {
return
}
@@ -354,9 +479,37 @@ func ReadData(r io.Reader) (data string, err error) {
err = fmt.Errorf("error data len")
return
}
+ return
+}
+func ReadData(r io.Reader) (data string, err error) {
+ _data, err := ReadBytes(r)
+ if err != nil {
+ return
+ }
data = string(_data)
return
}
+
+//non typed packet with Bytes
+func ReadPacketBytes(r io.Reader, data ...*[]byte) (err error) {
+ for _, d := range data {
+ *d, err = ReadBytes(r)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+func BuildPacketBytes(data ...[]byte) []byte {
+ pkg := new(bytes.Buffer)
+ for _, d := range data {
+ binary.Write(pkg, binary.LittleEndian, uint64(len(d)))
+ binary.Write(pkg, binary.LittleEndian, d)
+ }
+ return pkg.Bytes()
+}
+
+//non typed packet with string
func ReadPacketData(r io.Reader, data ...*string) (err error) {
for _, d := range data {
*d, err = ReadData(r)
@@ -366,13 +519,50 @@ func ReadPacketData(r io.Reader, data ...*string) (err error) {
}
return
}
-func ReadPacket(r io.Reader, typ *uint8, data ...*string) (err error) {
+func BuildPacketData(data ...string) []byte {
+ pkg := new(bytes.Buffer)
+ for _, d := range data {
+ bytes := []byte(d)
+ binary.Write(pkg, binary.LittleEndian, uint64(len(bytes)))
+ binary.Write(pkg, binary.LittleEndian, bytes)
+ }
+ return pkg.Bytes()
+}
+
+//typed packet with bytes
+func ReadBytesPacket(r io.Reader, packetType *uint8, data ...*[]byte) (err error) {
var connType uint8
err = binary.Read(r, binary.LittleEndian, &connType)
if err != nil {
return
}
- *typ = connType
+ *packetType = connType
+ for _, d := range data {
+ *d, err = ReadBytes(r)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+func BuildBytesPacket(packetType uint8, data ...[]byte) []byte {
+ pkg := new(bytes.Buffer)
+ binary.Write(pkg, binary.LittleEndian, packetType)
+ for _, d := range data {
+ binary.Write(pkg, binary.LittleEndian, uint64(len(d)))
+ binary.Write(pkg, binary.LittleEndian, d)
+ }
+ return pkg.Bytes()
+}
+
+//typed packet with string
+func ReadPacket(r io.Reader, packetType *uint8, data ...*string) (err error) {
+ var connType uint8
+ err = binary.Read(r, binary.LittleEndian, &connType)
+ if err != nil {
+ return
+ }
+ *packetType = connType
for _, d := range data {
*d, err = ReadData(r)
if err != nil {
@@ -381,25 +571,18 @@ func ReadPacket(r io.Reader, typ *uint8, data ...*string) (err error) {
}
return
}
-func BuildPacket(typ uint8, data ...string) []byte {
+
+func BuildPacket(packetType uint8, data ...string) []byte {
pkg := new(bytes.Buffer)
- binary.Write(pkg, binary.LittleEndian, typ)
+ binary.Write(pkg, binary.LittleEndian, packetType)
for _, d := range data {
bytes := []byte(d)
- binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
- binary.Write(pkg, binary.LittleEndian, bytes)
- }
- return pkg.Bytes()
-}
-func BuildPacketData(data ...string) []byte {
- pkg := new(bytes.Buffer)
- for _, d := range data {
- bytes := []byte(d)
- binary.Write(pkg, binary.LittleEndian, uint16(len(bytes)))
+ binary.Write(pkg, binary.LittleEndian, uint64(len(bytes)))
binary.Write(pkg, binary.LittleEndian, bytes)
}
return pkg.Bytes()
}
+
func SubStr(str string, start, end int) string {
if len(str) == 0 {
return ""
@@ -419,12 +602,21 @@ func SubBytes(bytes []byte, start, end int) []byte {
return bytes[start:end]
}
func TlsBytes(cert, key string) (certBytes, keyBytes []byte, err error) {
- certBytes, err = ioutil.ReadFile(cert)
+ base64Prefix := "base64://"
+ if strings.HasPrefix(cert, base64Prefix) {
+ certBytes, err = base64.StdEncoding.DecodeString(cert[len(base64Prefix):])
+ } else {
+ certBytes, err = ioutil.ReadFile(cert)
+ }
if err != nil {
err = fmt.Errorf("err : %s", err)
return
}
- keyBytes, err = ioutil.ReadFile(key)
+ if strings.HasPrefix(key, base64Prefix) {
+ keyBytes, err = base64.StdEncoding.DecodeString(key[len(base64Prefix):])
+ } else {
+ keyBytes, err = ioutil.ReadFile(key)
+ }
if err != nil {
err = fmt.Errorf("err : %s", err)
return
@@ -495,7 +687,7 @@ func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, er
body, err = ioutil.ReadAll(resp.Body)
return
}
-func IsIternalIP(domainOrIP string, always bool) bool {
+func IsInternalIP(domainOrIP string, always bool) bool {
var outIPs []net.IP
var err error
var isDomain bool
@@ -507,7 +699,7 @@ func IsIternalIP(domainOrIP string, always bool) bool {
}
if isDomain {
- outIPs, err = MyLookupIP(domainOrIP)
+ outIPs, err = LookupIP(domainOrIP)
} else {
outIPs = []net.IP{net.ParseIP(domainOrIP)}
}
@@ -515,6 +707,7 @@ func IsIternalIP(domainOrIP string, always bool) bool {
if err != nil {
return false
}
+
for _, ip := range outIPs {
if ip.IsLoopback() {
return true
@@ -522,7 +715,7 @@ func IsIternalIP(domainOrIP string, always bool) bool {
if ip.To4().Mask(net.IPv4Mask(255, 0, 0, 0)).String() == "10.0.0.0" {
return true
}
- if ip.To4().Mask(net.IPv4Mask(255, 0, 0, 0)).String() == "192.168.0.0" {
+ if ip.To4().Mask(net.IPv4Mask(255, 255, 0, 0)).String() == "192.168.0.0" {
return true
}
if ip.To4().Mask(net.IPv4Mask(255, 0, 0, 0)).String() == "172.0.0.0" {
@@ -585,6 +778,63 @@ func RemoveProxyHeaders(head []byte) []byte {
func InsertProxyHeaders(head []byte, headers string) []byte {
return bytes.Replace(head, []byte("\r\n"), []byte("\r\n"+headers), 1)
}
+func LBMethod(key string) int {
+ typs := map[string]int{"weight": lb.SELECT_WEITHT, "leasttime": lb.SELECT_LEASTTIME, "leastconn": lb.SELECT_LEASTCONN, "hash": lb.SELECT_HASH, "roundrobin": lb.SELECT_ROUNDROBIN}
+ return typs[key]
+}
+func UDPCopy(dst, src *net.UDPConn, dstAddr net.Addr, readTimeout time.Duration, beforeWriteFn func(data []byte) []byte, deferFn func(e interface{})) {
+ go func() {
+ defer func() {
+ deferFn(recover())
+ }()
+ buf := LeakyBuffer.Get()
+ defer LeakyBuffer.Put(buf)
+ for {
+ if readTimeout > 0 {
+ src.SetReadDeadline(time.Now().Add(readTimeout))
+ }
+ n, err := src.Read(buf)
+ if readTimeout > 0 {
+ src.SetReadDeadline(time.Time{})
+ }
+ if err != nil {
+ if IsNetClosedErr(err) || IsNetTimeoutErr(err) || IsNetRefusedErr(err) {
+ return
+ }
+ continue
+ }
+ _, err = dst.WriteTo(beforeWriteFn(buf[:n]), dstAddr)
+ if err != nil {
+ if IsNetClosedErr(err) {
+ return
+ }
+ continue
+ }
+ }
+ }()
+}
+func IsNetClosedErr(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "use of closed network connection")
+}
+func IsNetTimeoutErr(err error) bool {
+ if err == nil {
+ return false
+ }
+ e, ok := err.(net.Error)
+ return ok && e.Timeout()
+}
+func IsNetRefusedErr(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "connection refused")
+}
+func IsNetDeadlineErr(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "i/o deadline reached")
+}
+func IsNetSocketNotConnectedErr(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "socket is not connected")
+}
+func NewDefaultLogger() *logger.Logger {
+ return logger.New(os.Stderr, "", logger.LstdFlags)
+}
// type sockaddr struct {
// family uint16
@@ -647,7 +897,8 @@ func InsertProxyHeaders(head []byte, headers string) []byte {
net.LookupIP may cause deadlock in windows
https://github.com/golang/go/issues/24178
*/
-func MyLookupIP(host string) ([]net.IP, error) {
+
+func LookupIP(host string) ([]net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(3))
defer func() {
diff --git a/utils/io-limiter.go b/utils/iolimiter/iolimiter.go
similarity index 51%
rename from utils/io-limiter.go
rename to utils/iolimiter/iolimiter.go
index 5162d15..8bc478f 100644
--- a/utils/io-limiter.go
+++ b/utils/iolimiter/iolimiter.go
@@ -1,8 +1,9 @@
-package utils
+package iolimiter
import (
"context"
"io"
+ "net"
"time"
"golang.org/x/time/rate"
@@ -22,6 +23,86 @@ type Writer struct {
ctx context.Context
}
+type conn struct {
+ net.Conn
+ r io.Reader
+ w io.Writer
+ readLimiter *rate.Limiter
+ writeLimiter *rate.Limiter
+ ctx context.Context
+}
+
+//NewtRateLimitConn sets rate limit (bytes/sec) to the Conn read and write.
+func NewtConn(c net.Conn, bytesPerSec float64) net.Conn {
+ s := &conn{
+ Conn: c,
+ r: c,
+ w: c,
+ ctx: context.Background(),
+ }
+ s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
+ s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
+ s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
+ s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
+ return s
+}
+
+//NewtRateLimitReaderConn sets rate limit (bytes/sec) to the Conn read.
+func NewReaderConn(c net.Conn, bytesPerSec float64) net.Conn {
+ s := &conn{
+ Conn: c,
+ r: c,
+ w: c,
+ ctx: context.Background(),
+ }
+ s.readLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
+ s.readLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
+ return s
+}
+
+//NewtRateLimitWriterConn sets rate limit (bytes/sec) to the Conn write.
+func NewWriterConn(c net.Conn, bytesPerSec float64) net.Conn {
+ s := &conn{
+ Conn: c,
+ r: c,
+ w: c,
+ ctx: context.Background(),
+ }
+ s.writeLimiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
+ s.writeLimiter.AllowN(time.Now(), burstLimit) // spend initial burst
+ return s
+}
+
+// Read reads bytes into p.
+func (s *conn) Read(p []byte) (int, error) {
+ if s.readLimiter == nil {
+ return s.r.Read(p)
+ }
+ n, err := s.r.Read(p)
+ if err != nil {
+ return n, err
+ }
+ if err := s.readLimiter.WaitN(s.ctx, n); err != nil {
+ return n, err
+ }
+ return n, nil
+}
+
+// Write writes bytes from p.
+func (s *conn) Write(p []byte) (int, error) {
+ if s.writeLimiter == nil {
+ return s.w.Write(p)
+ }
+ n, err := s.w.Write(p)
+ if err != nil {
+ return n, err
+ }
+ if err := s.writeLimiter.WaitN(s.ctx, n); err != nil {
+ return n, err
+ }
+ return n, err
+}
+
// NewReader returns a reader that implements io.Reader with rate limiting.
func NewReader(r io.Reader) *Reader {
return &Reader{
diff --git a/utils/jumper/jumper.go b/utils/jumper/jumper.go
index e66c85a..11fe4eb 100644
--- a/utils/jumper/jumper.go
+++ b/utils/jumper/jumper.go
@@ -51,7 +51,10 @@ func (j *Jumper) dialHTTPS(address string, timeout time.Duration) (conn net.Conn
}
pb := new(bytes.Buffer)
pb.Write([]byte(fmt.Sprintf("CONNECT %s HTTP/1.1\r\n", address)))
+ pb.WriteString(fmt.Sprintf("Host: %s\r\n", address))
+ pb.WriteString(fmt.Sprintf("Proxy-Host: %s\r\n", address))
pb.WriteString("Proxy-Connection: Keep-Alive\r\n")
+ pb.WriteString("Connection: Keep-Alive\r\n")
if j.proxyURL.User != nil {
p, _ := j.proxyURL.User.Password()
u := fmt.Sprintf("%s:%s", j.proxyURL.User.Username(), p)
diff --git a/utils/lb/backend.go b/utils/lb/backend.go
new file mode 100644
index 0000000..eac46ba
--- /dev/null
+++ b/utils/lb/backend.go
@@ -0,0 +1,215 @@
+package lb
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "runtime/debug"
+ "sync"
+ "time"
+
+ "github.com/snail007/goproxy/utils/dnsx"
+)
+
+// BackendConfig it's the configuration loaded
+type BackendConfig struct {
+ Address string
+
+ ActiveAfter int
+ InactiveAfter int
+ Weight int
+
+ Timeout time.Duration
+ RetryTime time.Duration
+
+ IsMuxCheck bool
+ ConnFactory func(address string, timeout time.Duration) (net.Conn, error)
+}
+type BackendsConfig []*BackendConfig
+
+// BackendControl keep the control data
+type BackendControl struct {
+ Failed bool // The last request failed
+ Active bool
+
+ InactiveTries int
+ ActiveTries int
+
+ Connections int
+
+ ConnectUsedMillisecond int
+
+ isStop bool
+}
+
+// Backend structure
+type Backend struct {
+ BackendConfig
+ BackendControl
+ sync.RWMutex
+ log *log.Logger
+ dr *dnsx.DomainResolver
+}
+
+type Backends []*Backend
+
+func NewBackend(backendConfig BackendConfig, dr *dnsx.DomainResolver, log *log.Logger) (*Backend, error) {
+
+ if backendConfig.Address == "" {
+ return nil, errors.New("Address rquired")
+ }
+ if backendConfig.ActiveAfter == 0 {
+ backendConfig.ActiveAfter = 2
+ }
+ if backendConfig.InactiveAfter == 0 {
+ backendConfig.InactiveAfter = 3
+ }
+ if backendConfig.Weight == 0 {
+ backendConfig.Weight = 1
+ }
+ if backendConfig.Timeout == 0 {
+ backendConfig.Timeout = time.Millisecond * 1500
+ }
+ if backendConfig.RetryTime == 0 {
+ backendConfig.RetryTime = time.Millisecond * 2000
+ }
+ return &Backend{
+ dr: dr,
+ log: log,
+ BackendConfig: backendConfig,
+ BackendControl: BackendControl{
+ Failed: true,
+ Active: false,
+ InactiveTries: 0,
+ ActiveTries: 0,
+ Connections: 0,
+ ConnectUsedMillisecond: 0,
+ isStop: false,
+ },
+ }, nil
+}
+func (b *Backend) StopHeartCheck() {
+ b.isStop = true
+}
+
+func (b *Backend) IncreasConns() {
+ b.RWMutex.Lock()
+ b.Connections++
+ b.RWMutex.Unlock()
+}
+
+func (b *Backend) DecreaseConns() {
+ b.RWMutex.Lock()
+ b.Connections--
+ b.RWMutex.Unlock()
+}
+
+func (b *Backend) StartHeartCheck() {
+ if b.IsMuxCheck {
+ b.startMuxHeartCheck()
+ } else {
+ b.startTCPHeartCheck()
+ }
+}
+func (b *Backend) startMuxHeartCheck() {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ for {
+ if b.isStop {
+ return
+ }
+ var c net.Conn
+ var err error
+ start := time.Now().UnixNano() / int64(time.Microsecond)
+ c, err = b.getConn()
+ b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start)
+ if err != nil {
+ b.Active = false
+ time.Sleep(time.Second * 2)
+ continue
+ } else {
+ b.Active = true
+ }
+ for {
+ buf := make([]byte, 1)
+ c.Read(buf)
+ buf = nil
+ break
+ }
+ b.Active = false
+ }
+ }()
+}
+
+// Monitoring the backend
+func (b *Backend) startTCPHeartCheck() {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ for {
+ if b.isStop {
+ return
+ }
+ var c net.Conn
+ var err error
+ start := time.Now().UnixNano() / int64(time.Microsecond)
+ c, err = b.getConn()
+ b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start)
+ if err == nil {
+ c.Close()
+ }
+ if err != nil {
+ b.RWMutex.Lock()
+ // Max tries before consider inactive
+ if b.InactiveTries >= b.InactiveAfter {
+ //b.log.Printf("Backend inactive [%s]", b.Address)
+ b.Active = false
+ b.ActiveTries = 0
+ } else {
+ // Ok that guy it's out of the game
+ b.Failed = true
+ b.InactiveTries++
+ //b.log.Printf("Error to check address [%s] tries [%d]", b.Address, b.InactiveTries)
+ }
+ b.RWMutex.Unlock()
+ } else {
+
+ // Ok, let's keep working boys
+ b.RWMutex.Lock()
+ if b.ActiveTries >= b.ActiveAfter {
+ if b.Failed {
+ //log.Printf("Backend active [%s]", b.Address)
+ }
+ b.Failed = false
+ b.Active = true
+ b.InactiveTries = 0
+ } else {
+ b.ActiveTries++
+ }
+ b.RWMutex.Unlock()
+ }
+ time.Sleep(b.RetryTime)
+ }
+ }()
+}
+func (b *Backend) getConn() (conn net.Conn, err error) {
+ address := b.Address
+ if b.dr != nil && b.dr.DnsAddress() != "" {
+ address, err = b.dr.Resolve(b.Address)
+ if err != nil {
+ b.log.Printf("dns error %s , ERR:%s", b.Address, err)
+ }
+ }
+ if b.ConnFactory != nil {
+ return b.ConnFactory(address, b.Timeout)
+ }
+ return net.DialTimeout("tcp", address, b.Timeout)
+}
diff --git a/utils/lb/lb.go b/utils/lb/lb.go
new file mode 100644
index 0000000..8d0b2b0
--- /dev/null
+++ b/utils/lb/lb.go
@@ -0,0 +1,687 @@
+package lb
+
+import (
+ "crypto/md5"
+ "log"
+ "net"
+ "sync"
+
+ "github.com/snail007/goproxy/utils/dnsx"
+)
+
+const (
+ SELECT_ROUNDROBIN = iota
+ SELECT_LEASTCONN
+ SELECT_HASH
+ SELECT_WEITHT
+ SELECT_LEASTTIME
+)
+
+type Selector interface {
+ Select(srcAddr string) (addr string)
+ SelectBackend(srcAddr string) (b *Backend)
+ IncreasConns(addr string)
+ DecreaseConns(addr string)
+ Stop()
+ Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger)
+ IsActive() bool
+ ActiveCount() (count int)
+ Backends() (bs []*Backend)
+}
+
+type Group struct {
+ selector *Selector
+ log *log.Logger
+ dr *dnsx.DomainResolver
+ lock *sync.Mutex
+ last *Backend
+ debug bool
+}
+
+func NewGroup(selectType int, configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger, debug bool) Group {
+ bks := []*Backend{}
+ for _, c := range configs {
+ b, _ := NewBackend(*c, dr, log)
+ bks = append(bks, b)
+ }
+ if len(bks) > 1 {
+ for _, b := range bks {
+ b.StartHeartCheck()
+ }
+ }
+ var s Selector
+ switch selectType {
+ case SELECT_ROUNDROBIN:
+ s = NewRoundRobin(bks, log, debug)
+ case SELECT_LEASTCONN:
+ s = NewLeastConn(bks, log, debug)
+ case SELECT_HASH:
+ s = NewHash(bks, log, debug)
+ case SELECT_WEITHT:
+ s = NewWeight(bks, log, debug)
+ case SELECT_LEASTTIME:
+ s = NewLeastTime(bks, log, debug)
+ }
+ return Group{
+ selector: &s,
+ log: log,
+ dr: dr,
+ lock: &sync.Mutex{},
+ debug: debug,
+ }
+}
+func (g *Group) Select(srcAddr string, onlyHa bool) (addr string) {
+ if onlyHa {
+ g.lock.Lock()
+ defer g.lock.Unlock()
+ if g.last != nil && (g.last.Active || g.last.ConnectUsedMillisecond == 0) {
+ if g.debug {
+ g.log.Printf("############ choosed %s from lastest ############", g.last.Address)
+ printDebug(true, g.log, nil, srcAddr, (*g.selector).Backends())
+ }
+ return g.last.Address
+ }
+ g.last = (*g.selector).SelectBackend(srcAddr)
+ if !g.last.Active && g.last.ConnectUsedMillisecond > 0 {
+ g.log.Printf("###warn### lb selected empty , return default , for : %s", srcAddr)
+ }
+ return g.last.Address
+ }
+ b := (*g.selector).SelectBackend(srcAddr)
+ return b.Address
+
+}
+func (g *Group) IncreasConns(addr string) {
+ (*g.selector).IncreasConns(addr)
+}
+func (g *Group) DecreaseConns(addr string) {
+ (*g.selector).DecreaseConns(addr)
+}
+func (g *Group) Stop() {
+ if g.selector != nil {
+ (*g.selector).Stop()
+ }
+}
+func (g *Group) IsActive() bool {
+ return (*g.selector).IsActive()
+}
+func (g *Group) ActiveCount() (count int) {
+ return (*g.selector).ActiveCount()
+}
+func (g *Group) Reset(addrs []string) {
+ bks := (*g.selector).Backends()
+ if len(bks) == 0 {
+ return
+ }
+ cfg := bks[0].BackendConfig
+ configs := BackendsConfig{}
+ for _, addr := range addrs {
+ c := cfg
+ c.Address = addr
+ configs = append(configs, &c)
+ }
+ (*g.selector).Reset(configs, g.dr, g.log)
+}
+func (g *Group) Backends() []*Backend {
+ return (*g.selector).Backends()
+}
+
+//########################RoundRobin##########################
+type RoundRobin struct {
+ sync.Mutex
+ backendIndex int
+ backends Backends
+ log *log.Logger
+ debug bool
+}
+
+func NewRoundRobin(backends Backends, log *log.Logger, debug bool) Selector {
+ return &RoundRobin{
+ backends: backends,
+ log: log,
+ debug: debug,
+ }
+
+}
+func (r *RoundRobin) Select(srcAddr string) (addr string) {
+ return r.SelectBackend(srcAddr).Address
+}
+func (r *RoundRobin) SelectBackend(srcAddr string) (b *Backend) {
+ r.Lock()
+ defer r.Unlock()
+ defer func() {
+ printDebug(r.debug, r.log, b, srcAddr, r.backends)
+ }()
+ if len(r.backends) == 0 {
+ return
+ }
+ if len(r.backends) == 1 {
+ return r.backends[0]
+ }
+RETRY:
+ found := false
+ for _, b := range r.backends {
+ if b.Active {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return r.backends[0]
+ }
+ r.backendIndex++
+ if r.backendIndex > len(r.backends)-1 {
+ r.backendIndex = 0
+ }
+ if !r.backends[r.backendIndex].Active {
+ goto RETRY
+ }
+ return r.backends[r.backendIndex]
+}
+func (r *RoundRobin) IncreasConns(addr string) {
+
+}
+func (r *RoundRobin) DecreaseConns(addr string) {
+
+}
+func (r *RoundRobin) Stop() {
+ for _, b := range r.backends {
+ b.StopHeartCheck()
+ }
+}
+func (r *RoundRobin) Backends() []*Backend {
+ return r.backends
+}
+func (r *RoundRobin) IsActive() bool {
+ for _, b := range r.backends {
+ if b.Active {
+ return true
+ }
+ }
+ return false
+}
+func (r *RoundRobin) ActiveCount() (count int) {
+ for _, b := range r.backends {
+ if b.Active {
+ count++
+ }
+ }
+ return
+}
+func (r *RoundRobin) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
+ r.Lock()
+ defer r.Unlock()
+ r.Stop()
+ bks := []*Backend{}
+ for _, c := range configs {
+ b, _ := NewBackend(*c, dr, log)
+ bks = append(bks, b)
+ }
+ if len(bks) > 1 {
+ for _, b := range bks {
+ b.StartHeartCheck()
+ }
+ }
+ r.backends = bks
+}
+
+//########################LeastConn##########################
+
+type LeastConn struct {
+ sync.Mutex
+ backends Backends
+ log *log.Logger
+ debug bool
+}
+
+func NewLeastConn(backends []*Backend, log *log.Logger, debug bool) Selector {
+ lc := LeastConn{
+ backends: backends,
+ log: log,
+ debug: debug,
+ }
+ return &lc
+}
+
+func (lc *LeastConn) Select(srcAddr string) (addr string) {
+ return lc.SelectBackend(srcAddr).Address
+}
+func (lc *LeastConn) SelectBackend(srcAddr string) (b *Backend) {
+ lc.Lock()
+ defer lc.Unlock()
+ defer func() {
+ printDebug(lc.debug, lc.log, b, srcAddr, lc.backends)
+ }()
+ if len(lc.backends) == 0 {
+ return
+ }
+ if len(lc.backends) == 1 {
+ return lc.backends[0]
+ }
+ found := false
+ for _, b := range lc.backends {
+ if b.Active {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return lc.backends[0]
+ }
+ min := lc.backends[0].Connections
+ index := 0
+ for i, b := range lc.backends {
+ if b.Active {
+ min = b.Connections
+ index = i
+ break
+ }
+ }
+ for i, b := range lc.backends {
+ if b.Active && b.Connections <= min {
+ min = b.Connections
+ index = i
+ }
+ }
+ return lc.backends[index]
+}
+func (lc *LeastConn) IncreasConns(addr string) {
+ for _, a := range lc.backends {
+ if a.Address == addr {
+ a.IncreasConns()
+ return
+ }
+ }
+}
+func (lc *LeastConn) DecreaseConns(addr string) {
+ for _, a := range lc.backends {
+ if a.Address == addr {
+ a.DecreaseConns()
+ return
+ }
+ }
+}
+func (lc *LeastConn) Stop() {
+ for _, b := range lc.backends {
+ b.StopHeartCheck()
+ }
+}
+func (lc *LeastConn) IsActive() bool {
+ for _, b := range lc.backends {
+ if b.Active {
+ return true
+ }
+ }
+ return false
+}
+func (lc *LeastConn) ActiveCount() (count int) {
+ for _, b := range lc.backends {
+ if b.Active {
+ count++
+ }
+ }
+ return
+}
+func (lc *LeastConn) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
+ lc.Lock()
+ defer lc.Unlock()
+ lc.Stop()
+ bks := []*Backend{}
+ for _, c := range configs {
+ b, _ := NewBackend(*c, dr, log)
+ bks = append(bks, b)
+ }
+ if len(bks) > 1 {
+ for _, b := range bks {
+ b.StartHeartCheck()
+ }
+ }
+ lc.backends = bks
+}
+func (lc *LeastConn) Backends() []*Backend {
+ return lc.backends
+}
+
+//########################Hash##########################
+type Hash struct {
+ sync.Mutex
+ backends Backends
+ log *log.Logger
+ debug bool
+}
+
+func NewHash(backends Backends, log *log.Logger, debug bool) Selector {
+ return &Hash{
+ backends: backends,
+ log: log,
+ debug: debug,
+ }
+}
+func (h *Hash) Select(srcAddr string) (addr string) {
+ return h.SelectBackend(srcAddr).Address
+}
+func (h *Hash) SelectBackend(srcAddr string) (b *Backend) {
+ h.Lock()
+ defer h.Unlock()
+ defer func() {
+ printDebug(h.debug, h.log, b, srcAddr, h.backends)
+ }()
+ if len(h.backends) == 0 {
+ return
+ }
+ if len(h.backends) == 1 {
+ return h.backends[0]
+ }
+ i := 0
+ host, _, err := net.SplitHostPort(srcAddr)
+ if err != nil {
+ return
+ }
+ //porti, _ := strconv.Atoi(port)
+ //i += porti
+ for _, b := range md5.Sum([]byte(host)) {
+ i += int(b)
+ }
+RETRY:
+ found := false
+ for _, b := range h.backends {
+ if b.Active {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return h.backends[0]
+ }
+ k := i % len(h.backends)
+ if !h.backends[k].Active {
+ i++
+ goto RETRY
+ }
+ return h.backends[k]
+}
+func (h *Hash) IncreasConns(addr string) {
+
+}
+func (h *Hash) DecreaseConns(addr string) {
+
+}
+func (h *Hash) Stop() {
+ for _, b := range h.backends {
+ b.StopHeartCheck()
+ }
+}
+func (h *Hash) IsActive() bool {
+ for _, b := range h.backends {
+ if b.Active {
+ return true
+ }
+ }
+ return false
+}
+func (h *Hash) ActiveCount() (count int) {
+ for _, b := range h.backends {
+ if b.Active {
+ count++
+ }
+ }
+ return
+}
+func (h *Hash) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
+ h.Lock()
+ defer h.Unlock()
+ h.Stop()
+ bks := []*Backend{}
+ for _, c := range configs {
+ b, _ := NewBackend(*c, dr, log)
+ bks = append(bks, b)
+ }
+ if len(bks) > 1 {
+ for _, b := range bks {
+ b.StartHeartCheck()
+ }
+ }
+ h.backends = bks
+}
+func (h *Hash) Backends() []*Backend {
+ return h.backends
+}
+
+//########################Weight##########################
+type Weight struct {
+ sync.Mutex
+ backends Backends
+ log *log.Logger
+ debug bool
+}
+
+func NewWeight(backends Backends, log *log.Logger, debug bool) Selector {
+ return &Weight{
+ backends: backends,
+ log: log,
+ debug: debug,
+ }
+}
+func (w *Weight) Select(srcAddr string) (addr string) {
+ return w.SelectBackend(srcAddr).Address
+}
+func (w *Weight) SelectBackend(srcAddr string) (b *Backend) {
+ w.Lock()
+ defer w.Unlock()
+ defer func() {
+ printDebug(w.debug, w.log, b, srcAddr, w.backends)
+ }()
+ if len(w.backends) == 0 {
+ return
+ }
+ if len(w.backends) == 1 {
+ return w.backends[0]
+ }
+
+ found := false
+ for _, b := range w.backends {
+ if b.Active {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return w.backends[0]
+ }
+
+ min := w.backends[0].Connections / w.backends[0].Weight
+ index := 0
+ for i, b := range w.backends {
+ if b.Active {
+ min = b.Connections / b.Weight
+ index = i
+ break
+ }
+ }
+ for i, b := range w.backends {
+ if b.Active && b.Connections/b.Weight <= min {
+ min = b.Connections
+ index = i
+ }
+ }
+ return w.backends[index]
+}
+func (w *Weight) IncreasConns(addr string) {
+ w.Lock()
+ defer w.Unlock()
+ for _, a := range w.backends {
+ if a.Address == addr {
+ a.IncreasConns()
+ return
+ }
+ }
+}
+func (w *Weight) DecreaseConns(addr string) {
+ w.Lock()
+ defer w.Unlock()
+ for _, a := range w.backends {
+ if a.Address == addr {
+ a.DecreaseConns()
+ return
+ }
+ }
+}
+func (w *Weight) Stop() {
+ for _, b := range w.backends {
+ b.StopHeartCheck()
+ }
+}
+func (w *Weight) IsActive() bool {
+ for _, b := range w.backends {
+ if b.Active {
+ return true
+ }
+ }
+ return false
+}
+func (w *Weight) ActiveCount() (count int) {
+ for _, b := range w.backends {
+ if b.Active {
+ count++
+ }
+ }
+ return
+}
+func (w *Weight) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
+ w.Lock()
+ defer w.Unlock()
+ w.Stop()
+ bks := []*Backend{}
+ for _, c := range configs {
+ b, _ := NewBackend(*c, dr, log)
+ bks = append(bks, b)
+ }
+ if len(bks) > 1 {
+ for _, b := range bks {
+ b.StartHeartCheck()
+ }
+ }
+ w.backends = bks
+}
+func (w *Weight) Backends() []*Backend {
+ return w.backends
+}
+
+//########################LeastTime##########################
+
+type LeastTime struct {
+ sync.Mutex
+ backends Backends
+ log *log.Logger
+ debug bool
+}
+
+func NewLeastTime(backends []*Backend, log *log.Logger, debug bool) Selector {
+ lt := LeastTime{
+ backends: backends,
+ log: log,
+ debug: debug,
+ }
+ return <
+}
+
+func (lt *LeastTime) Select(srcAddr string) (addr string) {
+ return lt.SelectBackend(srcAddr).Address
+}
+func (lt *LeastTime) SelectBackend(srcAddr string) (b *Backend) {
+ lt.Lock()
+ defer lt.Unlock()
+ defer func() {
+ printDebug(lt.debug, lt.log, b, srcAddr, lt.backends)
+ }()
+ if len(lt.backends) == 0 {
+ return
+ }
+ if len(lt.backends) == 1 {
+ return lt.backends[0]
+ }
+ found := false
+ for _, b := range lt.backends {
+ if b.Active {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return lt.backends[0]
+ }
+ min := lt.backends[0].ConnectUsedMillisecond
+ index := 0
+ for i, b := range lt.backends {
+ if b.Active {
+ min = b.ConnectUsedMillisecond
+ index = i
+ break
+ }
+ }
+ for i, b := range lt.backends {
+ if b.Active && b.ConnectUsedMillisecond > 0 && b.ConnectUsedMillisecond <= min {
+ min = b.ConnectUsedMillisecond
+ index = i
+ }
+ }
+ return lt.backends[index]
+}
+func (lt *LeastTime) IncreasConns(addr string) {
+
+}
+func (lt *LeastTime) DecreaseConns(addr string) {
+
+}
+func (lt *LeastTime) Stop() {
+ for _, b := range lt.backends {
+ b.StopHeartCheck()
+ }
+}
+func (lt *LeastTime) IsActive() bool {
+ for _, b := range lt.backends {
+ if b.Active {
+ return true
+ }
+ }
+ return false
+}
+func (lt *LeastTime) ActiveCount() (count int) {
+ for _, b := range lt.backends {
+ if b.Active {
+ count++
+ }
+ }
+ return
+}
+func (lt *LeastTime) Reset(configs BackendsConfig, dr *dnsx.DomainResolver, log *log.Logger) {
+ lt.Lock()
+ defer lt.Unlock()
+ lt.Stop()
+ bks := []*Backend{}
+ for _, c := range configs {
+ b, _ := NewBackend(*c, dr, log)
+ bks = append(bks, b)
+ }
+ if len(bks) > 1 {
+ for _, b := range bks {
+ b.StartHeartCheck()
+ }
+ }
+ lt.backends = bks
+}
+func (lt *LeastTime) Backends() []*Backend {
+ return lt.backends
+}
+func printDebug(isDebug bool, log *log.Logger, selected *Backend, srcAddr string, backends []*Backend) {
+ if isDebug {
+ log.Printf("############ LB start ############\n")
+ if selected != nil {
+ log.Printf("choosed %s for %s\n", selected.Address, srcAddr)
+ }
+ for _, v := range backends {
+ log.Printf("addr:%s,conns:%d,time:%d,weight:%d,active:%v\n", v.Address, v.Connections, v.ConnectUsedMillisecond, v.Weight, v.Active)
+ }
+ log.Printf("############ LB end ############\n")
+ }
+}
diff --git a/utils/mapx/map.go b/utils/mapx/map.go
new file mode 100644
index 0000000..651a818
--- /dev/null
+++ b/utils/mapx/map.go
@@ -0,0 +1,355 @@
+package mapx
+
+import (
+ "encoding/json"
+ "fmt"
+ "runtime/debug"
+ "sync"
+)
+
+var SHARD_COUNT = 32
+
+// A "thread" safe map of type string:Anything.
+// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
+type ConcurrentMap []*ConcurrentMapShared
+
+// A "thread" safe string to anything map.
+type ConcurrentMapShared struct {
+ items map[string]interface{}
+ sync.RWMutex // Read Write mutex, guards access to internal map.
+}
+
+// Creates a new concurrent map.
+func NewConcurrentMap() ConcurrentMap {
+ m := make(ConcurrentMap, SHARD_COUNT)
+ for i := 0; i < SHARD_COUNT; i++ {
+ m[i] = &ConcurrentMapShared{items: make(map[string]interface{})}
+ }
+ return m
+}
+
+// Returns shard under given key
+func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared {
+ return m[uint(fnv32(key))%uint(SHARD_COUNT)]
+}
+
+func (m ConcurrentMap) MSet(data map[string]interface{}) {
+ for key, value := range data {
+ shard := m.GetShard(key)
+ shard.Lock()
+ shard.items[key] = value
+ shard.Unlock()
+ }
+}
+
+// Sets the given value under the specified key.
+func (m ConcurrentMap) Set(key string, value interface{}) {
+ // Get map shard.
+ shard := m.GetShard(key)
+ shard.Lock()
+ shard.items[key] = value
+ shard.Unlock()
+}
+
+// Callback to return new element to be inserted into the map
+// It is called while lock is held, therefore it MUST NOT
+// try to access other keys in same map, as it can lead to deadlock since
+// Go sync.RWLock is not reentrant
+type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{}
+
+// Insert or Update - updates existing element or inserts a new one using UpsertCb
+func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) {
+ shard := m.GetShard(key)
+ shard.Lock()
+ v, ok := shard.items[key]
+ res = cb(ok, v, value)
+ shard.items[key] = res
+ shard.Unlock()
+ return res
+}
+
+// Sets the given value under the specified key if no value was associated with it.
+func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
+ // Get map shard.
+ shard := m.GetShard(key)
+ shard.Lock()
+ _, ok := shard.items[key]
+ if !ok {
+ shard.items[key] = value
+ }
+ shard.Unlock()
+ return !ok
+}
+
+// Retrieves an element from map under given key.
+func (m ConcurrentMap) Get(key string) (interface{}, bool) {
+ // Get shard
+ shard := m.GetShard(key)
+ shard.RLock()
+ // Get item from shard.
+ val, ok := shard.items[key]
+ shard.RUnlock()
+ return val, ok
+}
+
+// Returns the number of elements within the map.
+func (m ConcurrentMap) Count() int {
+ count := 0
+ for i := 0; i < SHARD_COUNT; i++ {
+ shard := m[i]
+ shard.RLock()
+ count += len(shard.items)
+ shard.RUnlock()
+ }
+ return count
+}
+
+// Looks up an item under specified key
+func (m ConcurrentMap) Has(key string) bool {
+ // Get shard
+ shard := m.GetShard(key)
+ shard.RLock()
+ // See if element is within shard.
+ _, ok := shard.items[key]
+ shard.RUnlock()
+ return ok
+}
+
+// Removes an element from the map.
+func (m ConcurrentMap) Remove(key string) {
+ // Try to get shard.
+ shard := m.GetShard(key)
+ shard.Lock()
+ delete(shard.items, key)
+ shard.Unlock()
+}
+
+// Removes an element from the map and returns it
+func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
+ // Try to get shard.
+ shard := m.GetShard(key)
+ shard.Lock()
+ v, exists = shard.items[key]
+ delete(shard.items, key)
+ shard.Unlock()
+ return v, exists
+}
+
+// Checks if map is empty.
+func (m ConcurrentMap) IsEmpty() bool {
+ return m.Count() == 0
+}
+
+// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
+type Tuple struct {
+ Key string
+ Val interface{}
+}
+
+// Returns an iterator which could be used in a for range loop.
+//
+// Deprecated: using IterBuffered() will get a better performence
+func (m ConcurrentMap) Iter() <-chan Tuple {
+ chans := snapshot(m)
+ ch := make(chan Tuple)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ fanIn(chans, ch)
+ }()
+ return ch
+}
+
+// Returns a buffered iterator which could be used in a for range loop.
+func (m ConcurrentMap) IterBuffered() <-chan Tuple {
+ chans := snapshot(m)
+ total := 0
+ for _, c := range chans {
+ total += cap(c)
+ }
+ ch := make(chan Tuple, total)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ fanIn(chans, ch)
+ }()
+ return ch
+}
+
+// Returns a array of channels that contains elements in each shard,
+// which likely takes a snapshot of `m`.
+// It returns once the size of each buffered channel is determined,
+// before all the channels are populated using goroutines.
+func snapshot(m ConcurrentMap) (chans []chan Tuple) {
+ chans = make([]chan Tuple, SHARD_COUNT)
+ wg := sync.WaitGroup{}
+ wg.Add(SHARD_COUNT)
+ // Foreach shard.
+ for index, shard := range m {
+ go func(index int, shard *ConcurrentMapShared) {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ // Foreach key, value pair.
+ shard.RLock()
+ chans[index] = make(chan Tuple, len(shard.items))
+ wg.Done()
+ for key, val := range shard.items {
+ chans[index] <- Tuple{key, val}
+ }
+ shard.RUnlock()
+ close(chans[index])
+ }(index, shard)
+ }
+ wg.Wait()
+ return chans
+}
+
+// fanIn reads elements from channels `chans` into channel `out`
+func fanIn(chans []chan Tuple, out chan Tuple) {
+ wg := sync.WaitGroup{}
+ wg.Add(len(chans))
+ for _, ch := range chans {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ func(ch chan Tuple) {
+ for t := range ch {
+ out <- t
+ }
+ wg.Done()
+ }(ch)
+ }()
+ }
+ wg.Wait()
+ close(out)
+}
+
+// Returns all items as map[string]interface{}
+func (m ConcurrentMap) Items() map[string]interface{} {
+ tmp := make(map[string]interface{})
+
+ // Insert items to temporary map.
+ for item := range m.IterBuffered() {
+ tmp[item.Key] = item.Val
+ }
+
+ return tmp
+}
+
+// Iterator callback,called for every key,value found in
+// maps. RLock is held for all calls for a given shard
+// therefore callback sess consistent view of a shard,
+// but not across the shards
+type IterCb func(key string, v interface{})
+
+// Callback based iterator, cheapest way to read
+// all elements in a map.
+func (m ConcurrentMap) IterCb(fn IterCb) {
+ for idx := range m {
+ shard := (m)[idx]
+ shard.RLock()
+ for key, value := range shard.items {
+ fn(key, value)
+ }
+ shard.RUnlock()
+ }
+}
+
+// Return all keys as []string
+func (m ConcurrentMap) Keys() []string {
+ count := m.Count()
+ ch := make(chan string, count)
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ // Foreach shard.
+ wg := sync.WaitGroup{}
+ wg.Add(SHARD_COUNT)
+ for _, shard := range m {
+ go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
+ func(shard *ConcurrentMapShared) {
+ // Foreach key, value pair.
+ shard.RLock()
+ for key := range shard.items {
+ ch <- key
+ }
+ shard.RUnlock()
+ wg.Done()
+ }(shard)
+ }()
+ }
+ wg.Wait()
+ close(ch)
+ }()
+
+ // Generate keys
+ keys := make([]string, 0, count)
+ for k := range ch {
+ keys = append(keys, k)
+ }
+ return keys
+}
+
+//Reviles ConcurrentMap "private" variables to json marshal.
+func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
+ // Create a temporary map, which will hold all item spread across shards.
+ tmp := make(map[string]interface{})
+
+ // Insert items to temporary map.
+ for item := range m.IterBuffered() {
+ tmp[item.Key] = item.Val
+ }
+ return json.Marshal(tmp)
+}
+
+func fnv32(key string) uint32 {
+ hash := uint32(2166136261)
+ const prime32 = uint32(16777619)
+ for i := 0; i < len(key); i++ {
+ hash *= prime32
+ hash ^= uint32(key[i])
+ }
+ return hash
+}
+
+// Concurrent map uses Interface{} as its value, therefor JSON Unmarshal
+// will probably won't know which to type to unmarshal into, in such case
+// we'll end up with a value of type map[string]interface{}, In most cases this isn't
+// out value type, this is why we've decided to remove this functionality.
+
+// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) {
+// // Reverse process of Marshal.
+
+// tmp := make(map[string]interface{})
+
+// // Unmarshal into a single map.
+// if err := json.Unmarshal(b, &tmp); err != nil {
+// return nil
+// }
+
+// // foreach key,value pair in temporary map insert into our concurrent map.
+// for key, val := range tmp {
+// m.Set(key, val)
+// }
+// return nil
+// }
diff --git a/utils/serve-channel.go b/utils/serve-channel.go
index 80d6fcc..621747f 100644
--- a/utils/serve-channel.go
+++ b/utils/serve-channel.go
@@ -139,7 +139,7 @@ func (sc *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) {
}
return
}
-func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) {
+func (sc *ServerChannel) ListenUDP(fn func(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) {
addr := &net.UDPAddr{IP: net.ParseIP(sc.ip), Port: sc.port}
l, err := net.ListenUDP("udp", addr)
if err == nil {
@@ -161,7 +161,7 @@ func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *ne
sc.log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
- fn(packet, addr, srcAddr)
+ fn(l, packet, addr, srcAddr)
}()
} else {
sc.errAcceptHandler(err)
@@ -172,6 +172,7 @@ func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *ne
}
return
}
+
func (sc *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net.Conn), log *logger.Logger) (err error) {
lis, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", sc.ip, sc.port), config.Block, *config.DataShard, *config.ParityShard)
if err == nil {
diff --git a/utils/socks/client.go b/utils/socks/client.go
index 2460b3b..b173b3a 100644
--- a/utils/socks/client.go
+++ b/utils/socks/client.go
@@ -54,7 +54,7 @@ func NewClientConn(conn *net.Conn, network, target string, timeout time.Duration
s.header = header
}
if network == "udp" && target == "" {
- target = "0.0.0.0:1"
+ target = "0.0.0.0:0"
}
s.addr = target
return s
diff --git a/utils/socks/server.go b/utils/socks/server.go
index 8808562..6f72f50 100644
--- a/utils/socks/server.go
+++ b/utils/socks/server.go
@@ -98,6 +98,12 @@ func (s *ServerConn) Method() uint8 {
func (s *ServerConn) Target() string {
return s.target
}
+func (s *ServerConn) Host() string {
+ return s.dstHost
+}
+func (s *ServerConn) Port() string {
+ return s.dstPort
+}
func (s *ServerConn) Handshake() (err error) {
remoteAddr := (*s.conn).RemoteAddr()
//协商开始
diff --git a/utils/socks/structs.go b/utils/socks/structs.go
index a5ec922..1d90df8 100644
--- a/utils/socks/structs.go
+++ b/utils/socks/structs.go
@@ -281,6 +281,9 @@ func (p *PacketUDP) Build(destAddr string, data []byte) (err error) {
return
}
func (p *PacketUDP) Parse(b []byte) (err error) {
+ if len(b) < 9 {
+ return fmt.Errorf("too short packet")
+ }
p.frag = uint8(b[2])
if p.frag != 0 {
err = fmt.Errorf("FRAG only support for 0 , %v ,%v", p.frag, b[:4])
@@ -290,13 +293,22 @@ func (p *PacketUDP) Parse(b []byte) (err error) {
p.atype = b[3]
switch p.atype {
case ATYP_IPV4: //IP V4
+ if len(b) < 11 {
+ return fmt.Errorf("too short packet")
+ }
p.dstHost = net.IPv4(b[4], b[5], b[6], b[7]).String()
portIndex = 8
case ATYP_DOMAIN: //域名
domainLen := uint8(b[4])
+ if len(b) < int(domainLen)+7 {
+ return fmt.Errorf("too short packet")
+ }
p.dstHost = string(b[5 : 5+domainLen]) //b[4]表示域名的长度
portIndex = int(5 + domainLen)
case ATYP_IPV6: //IP V6
+ if len(b) < 22 {
+ return fmt.Errorf("too short packet")
+ }
p.dstHost = net.IP{b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15], b[16], b[17], b[18], b[19]}.String()
portIndex = 20
}
@@ -333,7 +345,9 @@ func (p *PacketUDP) Bytes() []byte {
func (p *PacketUDP) Host() string {
return p.dstHost
}
-
+func (p *PacketUDP) Addr() string {
+ return net.JoinHostPort(p.dstHost, p.dstPort)
+}
func (p *PacketUDP) Port() string {
return p.dstPort
}
diff --git a/utils/ss/conn.go b/utils/ss/conn.go
new file mode 100644
index 0000000..9d2ddc4
--- /dev/null
+++ b/utils/ss/conn.go
@@ -0,0 +1,180 @@
+package ss
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+)
+
+const (
+ OneTimeAuthMask byte = 0x10
+ AddrMask byte = 0xf
+)
+
+type Conn struct {
+ net.Conn
+ *Cipher
+ readBuf []byte
+ writeBuf []byte
+ chunkId uint32
+}
+
+func NewConn(c net.Conn, cipher *Cipher) *Conn {
+ return &Conn{
+ Conn: c,
+ Cipher: cipher,
+ readBuf: leakyBuf.Get(),
+ writeBuf: leakyBuf.Get()}
+}
+
+func (c *Conn) Close() error {
+ leakyBuf.Put(c.readBuf)
+ leakyBuf.Put(c.writeBuf)
+ return c.Conn.Close()
+}
+
+func RawAddr(addr string) (buf []byte, err error) {
+ host, portStr, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, fmt.Errorf("ss: address error %s %v", addr, err)
+ }
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return nil, fmt.Errorf("ss: invalid port %s", addr)
+ }
+
+ hostLen := len(host)
+ l := 1 + 1 + hostLen + 2 // addrType + lenByte + address + port
+ buf = make([]byte, l)
+ buf[0] = 3 // 3 means the address is domain name
+ buf[1] = byte(hostLen) // host address length followed by host address
+ copy(buf[2:], host)
+ binary.BigEndian.PutUint16(buf[2+hostLen:2+hostLen+2], uint16(port))
+ return
+}
+
+// This is intended for use by users implementing a local socks proxy.
+// rawaddr shoud contain part of the data in socks request, starting from the
+// ATYP field. (Refer to rfc1928 for more information.)
+func DialWithRawAddr(rawConn *net.Conn, rawaddr []byte, server string, cipher *Cipher) (c *Conn, err error) {
+ var conn net.Conn
+ if rawConn == nil {
+ conn, err = net.Dial("tcp", server)
+ }
+ if err != nil {
+ return
+ }
+ if rawConn != nil {
+ c = NewConn(*rawConn, cipher)
+ } else {
+ c = NewConn(conn, cipher)
+ }
+
+ if _, err = c.write(rawaddr); err != nil {
+ c.Close()
+ return nil, err
+ }
+ return
+}
+
+// addr should be in the form of host:port
+func Dial(addr, server string, cipher *Cipher) (c *Conn, err error) {
+ ra, err := RawAddr(addr)
+ if err != nil {
+ return
+ }
+ return DialWithRawAddr(nil, ra, server, cipher)
+}
+
+func (c *Conn) GetIv() (iv []byte) {
+ iv = make([]byte, len(c.iv))
+ copy(iv, c.iv)
+ return
+}
+
+func (c *Conn) GetKey() (key []byte) {
+ key = make([]byte, len(c.key))
+ copy(key, c.key)
+ return
+}
+
+func (c *Conn) IsOta() bool {
+ return c.ota
+}
+
+func (c *Conn) GetAndIncrChunkId() (chunkId uint32) {
+ chunkId = c.chunkId
+ c.chunkId += 1
+ return
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+ if c.dec == nil {
+ iv := make([]byte, c.info.ivLen)
+ if _, err = io.ReadFull(c.Conn, iv); err != nil {
+ return
+ }
+ if err = c.initDecrypt(iv); err != nil {
+ return
+ }
+ if len(c.iv) == 0 {
+ c.iv = iv
+ }
+ }
+
+ cipherData := c.readBuf
+ if len(b) > len(cipherData) {
+ cipherData = make([]byte, len(b))
+ } else {
+ cipherData = cipherData[:len(b)]
+ }
+
+ n, err = c.Conn.Read(cipherData)
+ if n > 0 {
+ c.decrypt(b[0:n], cipherData[0:n])
+ }
+ return
+}
+
+func (c *Conn) Write(b []byte) (n int, err error) {
+ nn := len(b)
+
+ headerLen := len(b) - nn
+
+ n, err = c.write(b)
+ // Make sure <= 0 <= len(b), where b is the slice passed in.
+ if n >= headerLen {
+ n -= headerLen
+ }
+ return
+}
+
+func (c *Conn) write(b []byte) (n int, err error) {
+ var iv []byte
+ if c.enc == nil {
+ iv, err = c.initEncrypt()
+ if err != nil {
+ return
+ }
+ }
+
+ cipherData := c.writeBuf
+ dataSize := len(b) + len(iv)
+ if dataSize > len(cipherData) {
+ cipherData = make([]byte, dataSize)
+ } else {
+ cipherData = cipherData[:dataSize]
+ }
+
+ if iv != nil {
+ // Put initialization vector in buffer, do a single write to send both
+ // iv and data.
+ copy(cipherData, iv)
+ }
+
+ c.encrypt(cipherData[len(iv):], b)
+ n, err = c.Conn.Write(cipherData)
+ return
+}
diff --git a/utils/ss/encrypt.go b/utils/ss/encrypt.go
new file mode 100644
index 0000000..3a40d64
--- /dev/null
+++ b/utils/ss/encrypt.go
@@ -0,0 +1,301 @@
+package ss
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/des"
+ "crypto/md5"
+ "crypto/rand"
+ "crypto/rc4"
+ "encoding/binary"
+ "errors"
+ "io"
+ "strings"
+
+ "github.com/Yawning/chacha20"
+ "golang.org/x/crypto/blowfish"
+ "golang.org/x/crypto/cast5"
+ "golang.org/x/crypto/salsa20/salsa"
+)
+
+var errEmptyPassword = errors.New("empty key")
+
+func md5sum(d []byte) []byte {
+ h := md5.New()
+ h.Write(d)
+ return h.Sum(nil)
+}
+
+func evpBytesToKey(password string, keyLen int) (key []byte) {
+ const md5Len = 16
+
+ cnt := (keyLen-1)/md5Len + 1
+ m := make([]byte, cnt*md5Len)
+ copy(m, md5sum([]byte(password)))
+
+ // Repeatedly call md5 until bytes generated is enough.
+ // Each call to md5 uses data: prev md5 sum + password.
+ d := make([]byte, md5Len+len(password))
+ start := 0
+ for i := 1; i < cnt; i++ {
+ start += md5Len
+ copy(d, m[start-md5Len:start])
+ copy(d[md5Len:], password)
+ copy(m[start:], md5sum(d))
+ }
+ return m[:keyLen]
+}
+
+type DecOrEnc int
+
+const (
+ Decrypt DecOrEnc = iota
+ Encrypt
+)
+
+func newStream(block cipher.Block, err error, key, iv []byte,
+ doe DecOrEnc) (cipher.Stream, error) {
+ if err != nil {
+ return nil, err
+ }
+ if doe == Encrypt {
+ return cipher.NewCFBEncrypter(block, iv), nil
+ } else {
+ return cipher.NewCFBDecrypter(block, iv), nil
+ }
+}
+
+func newAESCFBStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := aes.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newAESCTRStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+ return cipher.NewCTR(block, iv), nil
+}
+
+func newDESStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := des.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newBlowFishStream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := blowfish.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newCast5Stream(key, iv []byte, doe DecOrEnc) (cipher.Stream, error) {
+ block, err := cast5.NewCipher(key)
+ return newStream(block, err, key, iv, doe)
+}
+
+func newRC4MD5Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
+ h := md5.New()
+ h.Write(key)
+ h.Write(iv)
+ rc4key := h.Sum(nil)
+
+ return rc4.NewCipher(rc4key)
+}
+
+func newChaCha20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
+ return chacha20.NewCipher(key, iv)
+}
+
+func newChaCha20IETFStream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
+ return chacha20.NewCipher(key, iv)
+}
+
+type salsaStreamCipher struct {
+ nonce [8]byte
+ key [32]byte
+ counter int
+}
+
+func (c *salsaStreamCipher) XORKeyStream(dst, src []byte) {
+ var buf []byte
+ padLen := c.counter % 64
+ dataSize := len(src) + padLen
+ if cap(dst) >= dataSize {
+ buf = dst[:dataSize]
+ } else if leakyBufSize >= dataSize {
+ buf = leakyBuf.Get()
+ defer leakyBuf.Put(buf)
+ buf = buf[:dataSize]
+ } else {
+ buf = make([]byte, dataSize)
+ }
+
+ var subNonce [16]byte
+ copy(subNonce[:], c.nonce[:])
+ binary.LittleEndian.PutUint64(subNonce[len(c.nonce):], uint64(c.counter/64))
+
+ // It's difficult to avoid data copy here. src or dst maybe slice from
+ // Conn.Read/Write, which can't have padding.
+ copy(buf[padLen:], src[:])
+ salsa.XORKeyStream(buf, buf, &subNonce, &c.key)
+ copy(dst, buf[padLen:])
+
+ c.counter += len(src)
+}
+
+func newSalsa20Stream(key, iv []byte, _ DecOrEnc) (cipher.Stream, error) {
+ var c salsaStreamCipher
+ copy(c.nonce[:], iv[:8])
+ copy(c.key[:], key[:32])
+ return &c, nil
+}
+
+type cipherInfo struct {
+ keyLen int
+ ivLen int
+ newStream func(key, iv []byte, doe DecOrEnc) (cipher.Stream, error)
+}
+
+var cipherMethod = map[string]*cipherInfo{
+ "aes-128-cfb": {16, 16, newAESCFBStream},
+ "aes-192-cfb": {24, 16, newAESCFBStream},
+ "aes-256-cfb": {32, 16, newAESCFBStream},
+ "aes-128-ctr": {16, 16, newAESCTRStream},
+ "aes-192-ctr": {24, 16, newAESCTRStream},
+ "aes-256-ctr": {32, 16, newAESCTRStream},
+ "des-cfb": {8, 8, newDESStream},
+ "bf-cfb": {16, 8, newBlowFishStream},
+ "cast5-cfb": {16, 8, newCast5Stream},
+ "rc4-md5": {16, 16, newRC4MD5Stream},
+ "rc4-md5-6": {16, 6, newRC4MD5Stream},
+ "chacha20": {32, 8, newChaCha20Stream},
+ "chacha20-ietf": {32, 12, newChaCha20IETFStream},
+ "salsa20": {32, 8, newSalsa20Stream},
+}
+
+func CheckCipherMethod(method string) error {
+ if method == "" {
+ method = "aes-256-cfb"
+ }
+ _, ok := cipherMethod[method]
+ if !ok {
+ return errors.New("Unsupported encryption method: " + method)
+ }
+ return nil
+}
+
+type Cipher struct {
+ enc cipher.Stream
+ dec cipher.Stream
+ key []byte
+ info *cipherInfo
+ ota bool // one-time auth
+ iv []byte
+}
+
+// NewCipher creates a cipher that can be used in Dial() etc.
+// Use cipher.Copy() to create a new cipher with the same method and password
+// to avoid the cost of repeated cipher initialization.
+func NewCipher(method, password string) (c *Cipher, err error) {
+ if password == "" {
+ return nil, errEmptyPassword
+ }
+ var ota bool
+ if strings.HasSuffix(strings.ToLower(method), "-auth") {
+ method = method[:len(method)-5] // len("-auth") = 5
+ ota = true
+ } else {
+ ota = false
+ }
+ mi, ok := cipherMethod[method]
+ if !ok {
+ return nil, errors.New("Unsupported encryption method: " + method)
+ }
+
+ key := evpBytesToKey(password, mi.keyLen)
+
+ c = &Cipher{key: key, info: mi}
+
+ if err != nil {
+ return nil, err
+ }
+ c.ota = ota
+ return c, nil
+}
+
+// Initializes the block cipher with CFB mode, returns IV.
+func (c *Cipher) initEncrypt() (iv []byte, err error) {
+ if c.iv == nil {
+ iv = make([]byte, c.info.ivLen)
+ if _, err := io.ReadFull(rand.Reader, iv); err != nil {
+ return nil, err
+ }
+ c.iv = iv
+ } else {
+ iv = c.iv
+ }
+ c.enc, err = c.info.newStream(c.key, iv, Encrypt)
+ return
+}
+
+func (c *Cipher) initDecrypt(iv []byte) (err error) {
+ c.dec, err = c.info.newStream(c.key, iv, Decrypt)
+ return
+}
+
+func (c *Cipher) encrypt(dst, src []byte) {
+ c.enc.XORKeyStream(dst, src)
+}
+
+func (c *Cipher) decrypt(dst, src []byte) {
+ c.dec.XORKeyStream(dst, src)
+}
+func (c *Cipher) Encrypt(src []byte) (cipherData []byte) {
+ cipher := c.Copy()
+ iv, err := cipher.initEncrypt()
+ if err != nil {
+ return
+ }
+ packetLen := len(src) + len(iv)
+ cipherData = make([]byte, packetLen)
+ copy(cipherData, iv)
+ cipher.encrypt(cipherData[len(iv):], src)
+ return
+}
+
+func (c *Cipher) Decrypt(src []byte) (data []byte) {
+ cipher := c.Copy()
+ if len(src) < c.info.ivLen {
+ return
+ }
+ iv := make([]byte, c.info.ivLen)
+ copy(iv, src[:c.info.ivLen])
+ if err := cipher.initDecrypt(iv); err != nil {
+ return
+ }
+ data = make([]byte, len(src)-len(iv))
+ cipher.decrypt(data[0:], src[c.info.ivLen:])
+ return
+}
+
+// Copy creates a new cipher at it's initial state.
+func (c *Cipher) Copy() *Cipher {
+ // This optimization maybe not necessary. But without this function, we
+ // need to maintain a table cache for newTableCipher and use lock to
+ // protect concurrent access to that cache.
+
+ // AES and DES ciphers does not return specific types, so it's difficult
+ // to create copy. But their initizliation time is less than 4000ns on my
+ // 2.26 GHz Intel Core 2 Duo processor. So no need to worry.
+
+ // Currently, blow-fish and cast5 initialization cost is an order of
+ // maganitude slower than other ciphers. (I'm not sure whether this is
+ // because the current implementation is not highly optimized, or this is
+ // the nature of the algorithm.)
+
+ nc := *c
+ nc.enc = nil
+ nc.dec = nil
+ nc.ota = c.ota
+ return &nc
+}
diff --git a/utils/ss/util.go b/utils/ss/util.go
new file mode 100644
index 0000000..5a87120
--- /dev/null
+++ b/utils/ss/util.go
@@ -0,0 +1,82 @@
+package ss
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+
+ "github.com/snail007/goproxy/utils"
+)
+
+const leakyBufSize = 4108 // data.len(2) + hmacsha1(10) + data(4096)
+const maxNBuf = 2048
+
+var leakyBuf = utils.NewLeakyBuf(maxNBuf, leakyBufSize)
+
+const (
+ idType = 0 // address type index
+ idIP0 = 1 // ip addres start index
+ idDmLen = 1 // domain address length index
+ idDm0 = 2 // domain address start index
+
+ typeIPv4 = 1 // type is ipv4 address
+ typeDm = 3 // type is domain address
+ typeIPv6 = 4 // type is ipv6 address
+
+ lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
+ lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
+ lenDmBase = 2 // 1addrLen + 2port, plus addrLen
+ lenHmacSha1 = 10
+)
+
+func GetRequest(conn *Conn) (host string, err error) {
+
+ // buf size should at least have the same size with the largest possible
+ // request size (when addrType is 3, domain name has at most 256 bytes)
+ // 1(addrType) + 1(lenByte) + 255(max length address) + 2(port) + 10(hmac-sha1)
+ buf := make([]byte, 269)
+ // read till we get possible domain length field
+ if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
+ return
+ }
+
+ var reqStart, reqEnd int
+ addrType := buf[idType]
+ switch addrType & AddrMask {
+ case typeIPv4:
+ reqStart, reqEnd = idIP0, idIP0+lenIPv4
+ case typeIPv6:
+ reqStart, reqEnd = idIP0, idIP0+lenIPv6
+ case typeDm:
+ if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
+ return
+ }
+ reqStart, reqEnd = idDm0, idDm0+int(buf[idDmLen])+lenDmBase
+ default:
+ err = fmt.Errorf("addr type %d not supported", addrType&AddrMask)
+ return
+ }
+
+ if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
+ return
+ }
+
+ // Return string for typeIP is not most efficient, but browsers (Chrome,
+ // Safari, Firefox) all seems using typeDm exclusively. So this is not a
+ // big problem.
+ switch addrType & AddrMask {
+ case typeIPv4:
+ host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
+ case typeIPv6:
+ host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String()
+ case typeDm:
+ host = string(buf[idDm0 : idDm0+int(buf[idDmLen])])
+ }
+ // parse port
+ port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
+ host = net.JoinHostPort(host, strconv.Itoa(int(port)))
+
+ return
+}
diff --git a/utils/structs.go b/utils/structs.go
index 06a1ead..7d187f6 100644
--- a/utils/structs.go
+++ b/utils/structs.go
@@ -3,7 +3,6 @@ package utils
import (
"bufio"
"bytes"
- "crypto/tls"
"encoding/base64"
"errors"
"fmt"
@@ -12,21 +11,22 @@ import (
logger "log"
"net"
"net/url"
+ "runtime/debug"
"strings"
"sync"
"time"
- "github.com/snail007/goproxy/services/kcpcfg"
+ "github.com/snail007/goproxy/utils/dnsx"
+ "github.com/snail007/goproxy/utils/mapx"
"github.com/snail007/goproxy/utils/sni"
"github.com/golang/snappy"
- "github.com/miekg/dns"
)
type Checker struct {
- data ConcurrentMap
- blockedMap ConcurrentMap
- directMap ConcurrentMap
+ data mapx.ConcurrentMap
+ blockedMap mapx.ConcurrentMap
+ directMap mapx.ConcurrentMap
interval int64
timeout int
isStop bool
@@ -45,7 +45,7 @@ type CheckerItem struct {
//interval: recheck domain interval seconds
func NewChecker(timeout int, interval int64, blockedFile, directFile string, log *logger.Logger) Checker {
ch := Checker{
- data: NewConcurrentMap(),
+ data: mapx.NewConcurrentMap(),
interval: interval,
timeout: timeout,
isStop: false,
@@ -66,8 +66,8 @@ func NewChecker(timeout int, interval int64, blockedFile, directFile string, log
return ch
}
-func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) {
- dataMap = NewConcurrentMap()
+func (c *Checker) loadMap(f string) (dataMap mapx.ConcurrentMap) {
+ dataMap = mapx.NewConcurrentMap()
if PathExists(f) {
_contents, err := ioutil.ReadFile(f)
if err != nil {
@@ -88,11 +88,21 @@ func (c *Checker) Stop() {
}
func (c *Checker) start() {
go func() {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
//log.Printf("checker started")
for {
//log.Printf("checker did")
for _, v := range c.data.Items() {
go func(item CheckerItem) {
+ defer func() {
+ if e := recover(); e != nil {
+ fmt.Printf("crashed:%s", string(debug.Stack()))
+ }
+ }()
if c.isNeedCheck(item) {
//log.Printf("check %s", item.Host)
var conn net.Conn
@@ -198,18 +208,18 @@ func (c *Checker) Add(domain, address string) {
}
type BasicAuth struct {
- data ConcurrentMap
+ data mapx.ConcurrentMap
authURL string
authOkCode int
authTimeout int
authRetry int
- dns *DomainResolver
+ dns *dnsx.DomainResolver
log *logger.Logger
}
-func NewBasicAuth(dns *DomainResolver, log *logger.Logger) BasicAuth {
+func NewBasicAuth(dns *dnsx.DomainResolver, log *logger.Logger) BasicAuth {
return BasicAuth{
- data: NewConcurrentMap(),
+ data: mapx.NewConcurrentMap(),
dns: dns,
log: log,
}
@@ -528,53 +538,15 @@ func (req *HTTPRequest) addPortIfNot() (newHost string) {
return
}
-type OutConn struct {
- dur int
- typ string
- certBytes []byte
- keyBytes []byte
- caCertBytes []byte
- kcp kcpcfg.KCPConfigArgs
- address string
- timeout int
-}
-
-func NewOutConn(dur int, typ string, kcp kcpcfg.KCPConfigArgs, certBytes, keyBytes, caCertBytes []byte, address string, timeout int) (op OutConn) {
- return OutConn{
- dur: dur,
- typ: typ,
- certBytes: certBytes,
- keyBytes: keyBytes,
- caCertBytes: caCertBytes,
- kcp: kcp,
- address: address,
- timeout: timeout,
- }
-}
-func (op *OutConn) Get() (conn net.Conn, err error) {
- if op.typ == "tls" {
- var _conn tls.Conn
- _conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes, op.caCertBytes)
- if err == nil {
- conn = net.Conn(&_conn)
- }
- } else if op.typ == "kcp" {
- conn, err = ConnectKCPHost(op.address, op.kcp)
- } else {
- conn, err = ConnectHost(op.address, op.timeout)
- }
- return
-}
-
type ConnManager struct {
- pool ConcurrentMap
+ pool mapx.ConcurrentMap
l *sync.Mutex
log *logger.Logger
}
func NewConnManager(log *logger.Logger) ConnManager {
cm := ConnManager{
- pool: NewConcurrentMap(),
+ pool: mapx.NewConcurrentMap(),
l: &sync.Mutex{},
log: log,
}
@@ -582,11 +554,11 @@ func NewConnManager(log *logger.Logger) ConnManager {
}
func (cm *ConnManager) Add(key, ID string, conn *net.Conn) {
cm.pool.Upsert(key, nil, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} {
- var conns ConcurrentMap
+ var conns mapx.ConcurrentMap
if !exist {
- conns = NewConcurrentMap()
+ conns = mapx.NewConcurrentMap()
} else {
- conns = valueInMap.(ConcurrentMap)
+ conns = valueInMap.(mapx.ConcurrentMap)
}
if conns.Has(ID) {
v, _ := conns.Get(ID)
@@ -598,9 +570,9 @@ func (cm *ConnManager) Add(key, ID string, conn *net.Conn) {
})
}
func (cm *ConnManager) Remove(key string) {
- var conns ConcurrentMap
+ var conns mapx.ConcurrentMap
if v, ok := cm.pool.Get(key); ok {
- conns = v.(ConcurrentMap)
+ conns = v.(mapx.ConcurrentMap)
conns.IterCb(func(key string, v interface{}) {
CloseConn(v.(*net.Conn))
})
@@ -611,9 +583,9 @@ func (cm *ConnManager) Remove(key string) {
func (cm *ConnManager) RemoveOne(key string, ID string) {
defer cm.l.Unlock()
cm.l.Lock()
- var conns ConcurrentMap
+ var conns mapx.ConcurrentMap
if v, ok := cm.pool.Get(key); ok {
- conns = v.(ConcurrentMap)
+ conns = v.(mapx.ConcurrentMap)
if conns.Has(ID) {
v, _ := conns.Get(ID)
(*v.(*net.Conn)).Close()
@@ -631,11 +603,11 @@ func (cm *ConnManager) RemoveAll() {
type ClientKeyRouter struct {
keyChan chan string
- ctrl *ConcurrentMap
+ ctrl *mapx.ConcurrentMap
lock *sync.Mutex
}
-func NewClientKeyRouter(ctrl *ConcurrentMap, size int) ClientKeyRouter {
+func NewClientKeyRouter(ctrl *mapx.ConcurrentMap, size int) ClientKeyRouter {
return ClientKeyRouter{
keyChan: make(chan string, size),
ctrl: ctrl,
@@ -671,103 +643,6 @@ func (c *ClientKeyRouter) GetKey() string {
}
-type DomainResolver struct {
- ttl int
- dnsAddrress string
- data ConcurrentMap
- log *logger.Logger
-}
-type DomainResolverItem struct {
- ip string
- domain string
- expiredAt int64
-}
-
-func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver {
- return DomainResolver{
- ttl: ttl,
- dnsAddrress: dnsAddrress,
- data: NewConcurrentMap(),
- log: log,
- }
-}
-func (a *DomainResolver) MustResolve(address string) (ip string) {
- ip, _ = a.Resolve(address)
- return
-}
-func (a *DomainResolver) Resolve(address string) (ip string, err error) {
- domain := address
- port := ""
- fromCache := "false"
- defer func() {
- if port != "" {
- ip = net.JoinHostPort(ip, port)
- }
- a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache)
- //a.PrintData()
- }()
- if strings.Contains(domain, ":") {
- domain, port, err = net.SplitHostPort(domain)
- if err != nil {
- return
- }
- }
- if net.ParseIP(domain) != nil {
- ip = domain
- fromCache = "ip ignore"
- return
- }
- item, ok := a.data.Get(domain)
- if ok {
- //log.Println("find ", domain)
- if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
- ip = (*item.(*DomainResolverItem)).ip
- fromCache = "true"
- //log.Println("from cache ", domain)
- return
- }
- } else {
- item = &DomainResolverItem{
- domain: domain,
- }
-
- }
- c := new(dns.Client)
- c.DialTimeout = time.Millisecond * 5000
- c.ReadTimeout = time.Millisecond * 5000
- c.WriteTimeout = time.Millisecond * 5000
- m := new(dns.Msg)
- m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
- m.RecursionDesired = true
- r, _, err := c.Exchange(m, a.dnsAddrress)
- if r == nil {
- return
- }
- if r.Rcode != dns.RcodeSuccess {
- err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
- return
- }
- for _, answer := range r.Answer {
- if answer.Header().Rrtype == dns.TypeA {
- info := strings.Fields(answer.String())
- if len(info) >= 5 {
- ip = info[4]
- _item := item.(*DomainResolverItem)
- (*_item).expiredAt = time.Now().Unix() + int64(a.ttl)
- (*_item).ip = ip
- a.data.Set(domain, item)
- return
- }
- }
- }
- return
-}
-func (a *DomainResolver) PrintData() {
- for k, item := range a.data.Items() {
- d := item.(*DomainResolverItem)
- a.log.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt)
- }
-}
func NewCompStream(conn net.Conn) *CompStream {
c := new(CompStream)
c.conn = conn