diff options
| -rw-r--r-- | common/websocketconn/websocketconn.go | 135 | ||||
| -rw-r--r-- | common/websocketconn/websocketconn_test.go | 235 |
2 files changed, 327 insertions, 43 deletions
diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go index b87e657..73c2b25 100644 --- a/common/websocketconn/websocketconn.go +++ b/common/websocketconn/websocketconn.go @@ -7,64 +7,113 @@ import ( "github.com/gorilla/websocket" ) -// An abstraction that makes an underlying WebSocket connection look like an -// io.ReadWriteCloser. +// An abstraction that makes an underlying WebSocket connection look like a +// net.Conn. type Conn struct { - Ws *websocket.Conn - r io.Reader + *websocket.Conn + Reader io.Reader + Writer io.Writer } -// Implements io.Reader. func (conn *Conn) Read(b []byte) (n int, err error) { - var opCode int - if conn.r == nil { - // New message - var r io.Reader - for { - if opCode, r, err = conn.Ws.NextReader(); err != nil { - return - } - if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage { - continue - } + return conn.Reader.Read(b) +} - conn.r = r - break - } - } +func (conn *Conn) Write(b []byte) (n int, err error) { + return conn.Writer.Write(b) +} + +func (conn *Conn) Close() error { + // Ignore any error in trying to write a Close frame. + _ = conn.Conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second)) + return conn.Conn.Close() +} - n, err = conn.r.Read(b) - if err == io.EOF { - // Message finished - conn.r = nil - err = nil +func (conn *Conn) SetDeadline(t time.Time) error { + errRead := conn.Conn.SetReadDeadline(t) + errWrite := conn.Conn.SetWriteDeadline(t) + err := errRead + if err == nil { + err = errWrite } - return + return err } -// Implements io.Writer. -func (conn *Conn) Write(b []byte) (n int, err error) { - var w io.WriteCloser - if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil { - return +func readLoop(w io.Writer, ws *websocket.Conn) error { + for { + messageType, r, err := ws.NextReader() + if err != nil { + return err + } + if messageType != websocket.BinaryMessage && messageType != websocket.TextMessage { + continue + } + _, err = io.Copy(w, r) + if err != nil { + return err + } } - if n, err = w.Write(b); err != nil { - return + return nil +} + +func writeLoop(ws *websocket.Conn, r io.Reader) error { + for { + var buf [2048]byte + n, err := r.Read(buf[:]) + if err != nil { + return err + } + data := buf[:n] + w, err := ws.NextWriter(websocket.BinaryMessage) + if err != nil { + return err + } + n, err = w.Write(data) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } } - err = w.Close() - return } -// Implements io.Closer. -func (conn *Conn) Close() error { - // Ignore any error in trying to write a Close frame. - _ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second)) - return conn.Ws.Close() +// websocket.Conn methods start returning websocket.CloseError after the +// connection has been closed. We want to instead interpret that as io.EOF, just +// as you would find with a normal net.Conn. This only converts +// websocket.CloseErrors with known codes; other codes like CloseProtocolError +// and CloseAbnormalClosure will still be reported as anomalous. +func closeErrorToEOF(err error) error { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { + err = io.EOF + } + return err } // Create a new Conn. func New(ws *websocket.Conn) *Conn { - var conn Conn - conn.Ws = ws - return &conn + // Set up synchronous pipes to serialize reads and writes to the + // underlying websocket.Conn. + // + // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency + // "Connections support one concurrent reader and one concurrent writer. + // Applications are responsible for ensuring that no more than one + // goroutine calls the write methods (NextWriter, etc.) concurrently and + // that no more than one goroutine calls the read methods (NextReader, + // etc.) concurrently. The Close and WriteControl methods can be called + // concurrently with all other methods." + pr1, pw1 := io.Pipe() + go func() { + pw1.CloseWithError(closeErrorToEOF(readLoop(pw1, ws))) + }() + pr2, pw2 := io.Pipe() + go func() { + pr2.CloseWithError(closeErrorToEOF(writeLoop(ws, pr2))) + }() + return &Conn{ + Conn: ws, + Reader: pr1, + Writer: pw2, + } } diff --git a/common/websocketconn/websocketconn_test.go b/common/websocketconn/websocketconn_test.go new file mode 100644 index 0000000..9bc02ec --- /dev/null +++ b/common/websocketconn/websocketconn_test.go @@ -0,0 +1,235 @@ +package websocketconn + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// Returns a (server, client) pair of websocketconn.Conns. +func connPair() (*Conn, *Conn, error) { + // Will be assigned inside server.Handler. + var serverConn *Conn + + // Start up a web server to receive the request. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer ln.Close() + errCh := make(chan error) + server := http.Server{ + Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(*http.Request) bool { return true }, + } + ws, err := upgrader.Upgrade(rw, req, nil) + if err != nil { + errCh <- err + return + } + serverConn = New(ws) + close(errCh) + }), + } + defer server.Close() + go func() { + err := server.Serve(ln) + if err != nil && err != http.ErrServerClosed { + errCh <- err + } + }() + + // Make a request to the web server. + urlStr := (&url.URL{Scheme: "ws", Host: ln.Addr().String()}).String() + ws, _, err := (&websocket.Dialer{}).Dial(urlStr, nil) + if err != nil { + return nil, nil, err + } + clientConn := New(ws) + + // The server is finished when errCh is written to or closed. + err = <-errCh + if err != nil { + return nil, nil, err + } + return serverConn, clientConn, nil +} + +// Test that you can write in chunks and read the result concatenated. +func TestWrite(t *testing.T) { + tests := [][][]byte{ + {}, + {[]byte("foo")}, + {[]byte("foo"), []byte("bar")}, + {{}, []byte("foo"), {}, {}, []byte("bar")}, + } + + for _, test := range tests { + s, c, err := connPair() + if err != nil { + t.Fatal(err) + } + + // This is a little awkward because we need to read to and write + // from both ends of the Conn, and we need to do it in separate + // goroutines because otherwise a Write may block waiting for + // someone to Read it. Here we set up a loop in a separate + // goroutine, reading from the Conn s and writing to the dataCh + // and errCh channels, whose ultimate effect in the select loop + // below is like + // data, err := ioutil.ReadAll(s) + dataCh := make(chan []byte) + errCh := make(chan error) + go func() { + for { + var buf [1024]byte + n, err := s.Read(buf[:]) + if err != nil { + errCh <- err + return + } + p := make([]byte, n) + copy(p, buf[:]) + dataCh <- p + } + }() + + // Write the data to the client side of the Conn, one chunk at a + // time. + for i, chunk := range test { + n, err := c.Write(chunk) + if err != nil || n != len(chunk) { + t.Fatalf("%+q Write chunk %d: got (%d, %v), expected (%d, %v)", + test, i, n, err, len(chunk), nil) + } + } + // We cannot immediately c.Close here, because that closes the + // connection right away, without waiting for buffered data to + // be sent. + + // Pull data and err from the server goroutine above. + var data []byte + err = nil + loop: + for { + select { + case p := <-dataCh: + data = append(data, p...) + case err = <-errCh: + break loop + case <-time.After(100 * time.Millisecond): + break loop + } + } + s.Close() + c.Close() + + // Now data and err contain the result of reading everything + // from s. + expected := bytes.Join(test, []byte{}) + if err != nil || !bytes.Equal(data, expected) { + t.Fatalf("%+q ReadAll: got (%+q, %v), expected (%+q, %v)", + test, data, err, expected, nil) + } + } +} + +// Test that multiple goroutines may call Read on a Conn simultaneously. Run +// this with +// go test -race +func TestConcurrentRead(t *testing.T) { + s, c, err := connPair() + if err != nil { + t.Fatal(err) + } + defer s.Close() + + // Set up multiple threads reading from the same conn. + errCh := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) + for i := 0; i < 2; i++ { + go func() { + defer wg.Done() + _, err := io.Copy(ioutil.Discard, s) + if err != nil { + errCh <- err + } + }() + } + + // Write a bunch of data to the other end. + for i := 0; i < 2000; i++ { + _, err := c.Write([]byte(fmt.Sprintf("%d", i))) + if err != nil { + c.Close() + t.Fatalf("Write: %v", err) + } + } + c.Close() + + wg.Wait() + close(errCh) + + err = <-errCh + if err != nil { + t.Fatalf("Read: %v", err) + } +} + +// Test that multiple goroutines may call Write on a Conn simultaneously. Run +// this with +// go test -race +func TestConcurrentWrite(t *testing.T) { + s, c, err := connPair() + if err != nil { + t.Fatal(err) + } + + // Set up multiple threads writing to the same conn. + errCh := make(chan error, 3) + var wg sync.WaitGroup + wg.Add(2) + for i := 0; i < 2; i++ { + go func() { + defer wg.Done() + for j := 0; j < 1000; j++ { + _, err := fmt.Fprintf(s, "%d", j) + if err != nil { + errCh <- err + break + } + } + }() + } + go func() { + wg.Wait() + err := s.Close() + if err != nil { + errCh <- err + } + close(errCh) + }() + + // Read from the other end. + _, err = io.Copy(ioutil.Discard, c) + c.Close() + if err != nil { + t.Fatalf("Read: %v", err) + } + + err = <-errCh + if err != nil { + t.Fatalf("Write: %v", err) + } +} |
