summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client/lib/snowflake.go1
-rw-r--r--client/snowflake.go11
-rw-r--r--common/websocketconn/websocketconn.go16
-rw-r--r--proxy-go/snowflake.go4
-rw-r--r--server/server.go12
-rw-r--r--server/server_test.go8
6 files changed, 27 insertions, 25 deletions
diff --git a/client/lib/snowflake.go b/client/lib/snowflake.go
index a27c6a5..9ab6fc6 100644
--- a/client/lib/snowflake.go
+++ b/client/lib/snowflake.go
@@ -25,7 +25,6 @@ func Handler(socks SocksConnector, snowflakes SnowflakeCollector) error {
return errors.New("handler: Received invalid Snowflake")
}
- defer socks.Close()
defer snowflake.Close()
log.Println("---- Handler: snowflake assigned ----")
err := socks.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0})
diff --git a/client/snowflake.go b/client/snowflake.go
index 7cb9451..af416be 100644
--- a/client/snowflake.go
+++ b/client/snowflake.go
@@ -57,10 +57,13 @@ func socksAcceptLoop(ln *pt.SocksListener, snowflakes sf.SnowflakeCollector) {
break
}
log.Printf("SOCKS accepted: %v", conn.Req)
- err = sf.Handler(conn, snowflakes)
- if err != nil {
- log.Printf("handler error: %s", err)
- }
+ go func() {
+ defer conn.Close()
+ err = sf.Handler(conn, snowflakes)
+ if err != nil {
+ log.Printf("handler error: %s", err)
+ }
+ }()
}
}
diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go
index 7e12abf..b87e657 100644
--- a/common/websocketconn/websocketconn.go
+++ b/common/websocketconn/websocketconn.go
@@ -9,13 +9,13 @@ import (
// An abstraction that makes an underlying WebSocket connection look like an
// io.ReadWriteCloser.
-type WebSocketConn struct {
+type Conn struct {
Ws *websocket.Conn
r io.Reader
}
// Implements io.Reader.
-func (conn *WebSocketConn) Read(b []byte) (n int, err error) {
+func (conn *Conn) Read(b []byte) (n int, err error) {
var opCode int
if conn.r == nil {
// New message
@@ -43,7 +43,7 @@ func (conn *WebSocketConn) Read(b []byte) (n int, err error) {
}
// Implements io.Writer.
-func (conn *WebSocketConn) Write(b []byte) (n int, err error) {
+func (conn *Conn) Write(b []byte) (n int, err error) {
var w io.WriteCloser
if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
return
@@ -56,15 +56,15 @@ func (conn *WebSocketConn) Write(b []byte) (n int, err error) {
}
// Implements io.Closer.
-func (conn *WebSocketConn) Close() error {
+func (conn *Conn) Close() error {
// Ignore any error in trying to write a Close frame.
_ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
return conn.Ws.Close()
}
-// Create a new WebSocketConn.
-func NewWebSocketConn(ws *websocket.Conn) WebSocketConn {
- var conn WebSocketConn
+// Create a new Conn.
+func New(ws *websocket.Conn) *Conn {
+ var conn Conn
conn.Ws = ws
- return conn
+ return &conn
}
diff --git a/proxy-go/snowflake.go b/proxy-go/snowflake.go
index dce7b70..e964a07 100644
--- a/proxy-go/snowflake.go
+++ b/proxy-go/snowflake.go
@@ -285,10 +285,10 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
log.Printf("error dialing relay: %s", err)
return
}
- wsConn := websocketconn.NewWebSocketConn(ws)
+ wsConn := websocketconn.New(ws)
log.Printf("connected to relay")
defer wsConn.Close()
- CopyLoop(conn, &wsConn)
+ CopyLoop(conn, wsConn)
log.Printf("datachannelHandler ends")
}
diff --git a/server/server.go b/server/server.go
index 5ed56d3..a7aa444 100644
--- a/server/server.go
+++ b/server/server.go
@@ -52,7 +52,7 @@ additional HTTP listener on port 80 to work with ACME.
}
// Copy from WebSocket to socket and vice versa.
-func proxy(local *net.TCPConn, conn *websocketconn.WebSocketConn) {
+func proxy(local *net.TCPConn, conn *websocketconn.Conn) {
var wg sync.WaitGroup
wg.Add(2)
@@ -94,7 +94,9 @@ func clientAddr(clientIPParam string) string {
return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String()
}
-var upgrader = websocket.Upgrader{}
+var upgrader = websocket.Upgrader{
+ CheckOrigin: func(r *http.Request) bool { return true },
+}
type HTTPHandler struct{}
@@ -105,7 +107,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
- conn := websocketconn.NewWebSocketConn(ws)
+ conn := websocketconn.New(ws)
defer conn.Close()
// Pass the address of client as the remote address of incoming connection
@@ -123,7 +125,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
defer or.Close()
- proxy(or, &conn)
+ proxy(or, conn)
}
func initServer(addr *net.TCPAddr,
@@ -139,8 +141,6 @@ func initServer(addr *net.TCPAddr,
return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
}
- upgrader.CheckOrigin = func(r *http.Request) bool { return true }
-
var handler HTTPHandler
server := &http.Server{
Addr: addr.String(),
diff --git a/server/server_test.go b/server/server_test.go
index 7a72014..d4ada6e 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -60,13 +60,13 @@ type StubHandler struct{}
func (handler *StubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws, _ := upgrader.Upgrade(w, r, nil)
- conn := websocketconn.NewWebSocketConn(ws)
+ conn := websocketconn.New(ws)
defer conn.Close()
//dial stub OR
or, _ := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.ParseIP("localhost"), Port: 8889})
- proxy(or, &conn)
+ proxy(or, conn)
}
func Test(t *testing.T) {
@@ -90,7 +90,7 @@ func Test(t *testing.T) {
So(err, ShouldBeNil)
ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil)
- wsConn := websocketconn.NewWebSocketConn(ws)
+ wsConn := websocketconn.New(ws)
So(err, ShouldEqual, nil)
So(wsConn, ShouldNotEqual, nil)
@@ -133,7 +133,7 @@ func Test(t *testing.T) {
ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:8888", nil)
So(err, ShouldEqual, nil)
- wsConn := websocketconn.NewWebSocketConn(ws)
+ wsConn := websocketconn.New(ws)
So(wsConn, ShouldNotEqual, nil)
wsConn.Write([]byte("Hello"))