diff options
| -rw-r--r-- | client/lib/snowflake.go | 1 | ||||
| -rw-r--r-- | client/snowflake.go | 11 | ||||
| -rw-r--r-- | common/websocketconn/websocketconn.go | 16 | ||||
| -rw-r--r-- | proxy-go/snowflake.go | 4 | ||||
| -rw-r--r-- | server/server.go | 12 | ||||
| -rw-r--r-- | server/server_test.go | 8 |
6 files changed, 27 insertions, 25 deletions
diff --git a/client/lib/snowflake.go b/client/lib/snowflake.go index a27c6a5..9ab6fc6 100644 --- a/client/lib/snowflake.go +++ b/client/lib/snowflake.go @@ -25,7 +25,6 @@ func Handler(socks SocksConnector, snowflakes SnowflakeCollector) error { return errors.New("handler: Received invalid Snowflake") } - defer socks.Close() defer snowflake.Close() log.Println("---- Handler: snowflake assigned ----") err := socks.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0}) diff --git a/client/snowflake.go b/client/snowflake.go index 7cb9451..af416be 100644 --- a/client/snowflake.go +++ b/client/snowflake.go @@ -57,10 +57,13 @@ func socksAcceptLoop(ln *pt.SocksListener, snowflakes sf.SnowflakeCollector) { break } log.Printf("SOCKS accepted: %v", conn.Req) - err = sf.Handler(conn, snowflakes) - if err != nil { - log.Printf("handler error: %s", err) - } + go func() { + defer conn.Close() + err = sf.Handler(conn, snowflakes) + if err != nil { + log.Printf("handler error: %s", err) + } + }() } } diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go index 7e12abf..b87e657 100644 --- a/common/websocketconn/websocketconn.go +++ b/common/websocketconn/websocketconn.go @@ -9,13 +9,13 @@ import ( // An abstraction that makes an underlying WebSocket connection look like an // io.ReadWriteCloser. -type WebSocketConn struct { +type Conn struct { Ws *websocket.Conn r io.Reader } // Implements io.Reader. -func (conn *WebSocketConn) Read(b []byte) (n int, err error) { +func (conn *Conn) Read(b []byte) (n int, err error) { var opCode int if conn.r == nil { // New message @@ -43,7 +43,7 @@ func (conn *WebSocketConn) Read(b []byte) (n int, err error) { } // Implements io.Writer. -func (conn *WebSocketConn) Write(b []byte) (n int, err error) { +func (conn *Conn) Write(b []byte) (n int, err error) { var w io.WriteCloser if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil { return @@ -56,15 +56,15 @@ func (conn *WebSocketConn) Write(b []byte) (n int, err error) { } // Implements io.Closer. -func (conn *WebSocketConn) Close() error { +func (conn *Conn) Close() error { // Ignore any error in trying to write a Close frame. _ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second)) return conn.Ws.Close() } -// Create a new WebSocketConn. -func NewWebSocketConn(ws *websocket.Conn) WebSocketConn { - var conn WebSocketConn +// Create a new Conn. +func New(ws *websocket.Conn) *Conn { + var conn Conn conn.Ws = ws - return conn + return &conn } diff --git a/proxy-go/snowflake.go b/proxy-go/snowflake.go index dce7b70..e964a07 100644 --- a/proxy-go/snowflake.go +++ b/proxy-go/snowflake.go @@ -285,10 +285,10 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { log.Printf("error dialing relay: %s", err) return } - wsConn := websocketconn.NewWebSocketConn(ws) + wsConn := websocketconn.New(ws) log.Printf("connected to relay") defer wsConn.Close() - CopyLoop(conn, &wsConn) + CopyLoop(conn, wsConn) log.Printf("datachannelHandler ends") } diff --git a/server/server.go b/server/server.go index 5ed56d3..a7aa444 100644 --- a/server/server.go +++ b/server/server.go @@ -52,7 +52,7 @@ additional HTTP listener on port 80 to work with ACME. } // Copy from WebSocket to socket and vice versa. -func proxy(local *net.TCPConn, conn *websocketconn.WebSocketConn) { +func proxy(local *net.TCPConn, conn *websocketconn.Conn) { var wg sync.WaitGroup wg.Add(2) @@ -94,7 +94,9 @@ func clientAddr(clientIPParam string) string { return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String() } -var upgrader = websocket.Upgrader{} +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, +} type HTTPHandler struct{} @@ -105,7 +107,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - conn := websocketconn.NewWebSocketConn(ws) + conn := websocketconn.New(ws) defer conn.Close() // Pass the address of client as the remote address of incoming connection @@ -123,7 +125,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer or.Close() - proxy(or, &conn) + proxy(or, conn) } func initServer(addr *net.TCPAddr, @@ -139,8 +141,6 @@ func initServer(addr *net.TCPAddr, return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port) } - upgrader.CheckOrigin = func(r *http.Request) bool { return true } - var handler HTTPHandler server := &http.Server{ Addr: addr.String(), diff --git a/server/server_test.go b/server/server_test.go index 7a72014..d4ada6e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -60,13 +60,13 @@ type StubHandler struct{} func (handler *StubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ws, _ := upgrader.Upgrade(w, r, nil) - conn := websocketconn.NewWebSocketConn(ws) + conn := websocketconn.New(ws) defer conn.Close() //dial stub OR or, _ := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.ParseIP("localhost"), Port: 8889}) - proxy(or, &conn) + proxy(or, conn) } func Test(t *testing.T) { @@ -90,7 +90,7 @@ func Test(t *testing.T) { So(err, ShouldBeNil) ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil) - wsConn := websocketconn.NewWebSocketConn(ws) + wsConn := websocketconn.New(ws) So(err, ShouldEqual, nil) So(wsConn, ShouldNotEqual, nil) @@ -133,7 +133,7 @@ func Test(t *testing.T) { ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil) So(err, ShouldEqual, nil) - wsConn := websocketconn.NewWebSocketConn(ws) + wsConn := websocketconn.New(ws) So(wsConn, ShouldNotEqual, nil) wsConn.Write([]byte("Hello")) |
