summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/websocketconn/websocketconn.go135
-rw-r--r--common/websocketconn/websocketconn_test.go235
2 files changed, 327 insertions, 43 deletions
diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go
index b87e657..73c2b25 100644
--- a/common/websocketconn/websocketconn.go
+++ b/common/websocketconn/websocketconn.go
@@ -7,64 +7,113 @@ import (
"github.com/gorilla/websocket"
)
-// An abstraction that makes an underlying WebSocket connection look like an
-// io.ReadWriteCloser.
+// An abstraction that makes an underlying WebSocket connection look like a
+// net.Conn.
type Conn struct {
- Ws *websocket.Conn
- r io.Reader
+ *websocket.Conn
+ Reader io.Reader
+ Writer io.Writer
}
-// Implements io.Reader.
func (conn *Conn) Read(b []byte) (n int, err error) {
- var opCode int
- if conn.r == nil {
- // New message
- var r io.Reader
- for {
- if opCode, r, err = conn.Ws.NextReader(); err != nil {
- return
- }
- if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
- continue
- }
+ return conn.Reader.Read(b)
+}
- conn.r = r
- break
- }
- }
+func (conn *Conn) Write(b []byte) (n int, err error) {
+ return conn.Writer.Write(b)
+}
+
+func (conn *Conn) Close() error {
+ // Ignore any error in trying to write a Close frame.
+ _ = conn.Conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
+ return conn.Conn.Close()
+}
- n, err = conn.r.Read(b)
- if err == io.EOF {
- // Message finished
- conn.r = nil
- err = nil
+func (conn *Conn) SetDeadline(t time.Time) error {
+ errRead := conn.Conn.SetReadDeadline(t)
+ errWrite := conn.Conn.SetWriteDeadline(t)
+ err := errRead
+ if err == nil {
+ err = errWrite
}
- return
+ return err
}
-// Implements io.Writer.
-func (conn *Conn) Write(b []byte) (n int, err error) {
- var w io.WriteCloser
- if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
- return
+func readLoop(w io.Writer, ws *websocket.Conn) error {
+ for {
+ messageType, r, err := ws.NextReader()
+ if err != nil {
+ return err
+ }
+ if messageType != websocket.BinaryMessage && messageType != websocket.TextMessage {
+ continue
+ }
+ _, err = io.Copy(w, r)
+ if err != nil {
+ return err
+ }
}
- if n, err = w.Write(b); err != nil {
- return
+ return nil
+}
+
+func writeLoop(ws *websocket.Conn, r io.Reader) error {
+ for {
+ var buf [2048]byte
+ n, err := r.Read(buf[:])
+ if err != nil {
+ return err
+ }
+ data := buf[:n]
+ w, err := ws.NextWriter(websocket.BinaryMessage)
+ if err != nil {
+ return err
+ }
+ n, err = w.Write(data)
+ if err != nil {
+ return err
+ }
+ err = w.Close()
+ if err != nil {
+ return err
+ }
}
- err = w.Close()
- return
}
-// Implements io.Closer.
-func (conn *Conn) Close() error {
- // Ignore any error in trying to write a Close frame.
- _ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
- return conn.Ws.Close()
+// websocket.Conn methods start returning websocket.CloseError after the
+// connection has been closed. We want to instead interpret that as io.EOF, just
+// as you would find with a normal net.Conn. This only converts
+// websocket.CloseErrors with known codes; other codes like CloseProtocolError
+// and CloseAbnormalClosure will still be reported as anomalous.
+func closeErrorToEOF(err error) error {
+ if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
+ err = io.EOF
+ }
+ return err
}
// Create a new Conn.
func New(ws *websocket.Conn) *Conn {
- var conn Conn
- conn.Ws = ws
- return &conn
+ // Set up synchronous pipes to serialize reads and writes to the
+ // underlying websocket.Conn.
+ //
+ // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
+ // "Connections support one concurrent reader and one concurrent writer.
+ // Applications are responsible for ensuring that no more than one
+ // goroutine calls the write methods (NextWriter, etc.) concurrently and
+ // that no more than one goroutine calls the read methods (NextReader,
+ // etc.) concurrently. The Close and WriteControl methods can be called
+ // concurrently with all other methods."
+ pr1, pw1 := io.Pipe()
+ go func() {
+ pw1.CloseWithError(closeErrorToEOF(readLoop(pw1, ws)))
+ }()
+ pr2, pw2 := io.Pipe()
+ go func() {
+ pr2.CloseWithError(closeErrorToEOF(writeLoop(ws, pr2)))
+ }()
+ return &Conn{
+ Conn: ws,
+ Reader: pr1,
+ Writer: pw2,
+ }
}
diff --git a/common/websocketconn/websocketconn_test.go b/common/websocketconn/websocketconn_test.go
new file mode 100644
index 0000000..9bc02ec
--- /dev/null
+++ b/common/websocketconn/websocketconn_test.go
@@ -0,0 +1,235 @@
+package websocketconn
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/url"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+// Returns a (server, client) pair of websocketconn.Conns.
+func connPair() (*Conn, *Conn, error) {
+ // Will be assigned inside server.Handler.
+ var serverConn *Conn
+
+ // Start up a web server to receive the request.
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return nil, nil, err
+ }
+ defer ln.Close()
+ errCh := make(chan error)
+ server := http.Server{
+ Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ upgrader := websocket.Upgrader{
+ CheckOrigin: func(*http.Request) bool { return true },
+ }
+ ws, err := upgrader.Upgrade(rw, req, nil)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ serverConn = New(ws)
+ close(errCh)
+ }),
+ }
+ defer server.Close()
+ go func() {
+ err := server.Serve(ln)
+ if err != nil && err != http.ErrServerClosed {
+ errCh <- err
+ }
+ }()
+
+ // Make a request to the web server.
+ urlStr := (&url.URL{Scheme: "ws", Host: ln.Addr().String()}).String()
+ ws, _, err := (&websocket.Dialer{}).Dial(urlStr, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ clientConn := New(ws)
+
+ // The server is finished when errCh is written to or closed.
+ err = <-errCh
+ if err != nil {
+ return nil, nil, err
+ }
+ return serverConn, clientConn, nil
+}
+
+// Test that you can write in chunks and read the result concatenated.
+func TestWrite(t *testing.T) {
+ tests := [][][]byte{
+ {},
+ {[]byte("foo")},
+ {[]byte("foo"), []byte("bar")},
+ {{}, []byte("foo"), {}, {}, []byte("bar")},
+ }
+
+ for _, test := range tests {
+ s, c, err := connPair()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // This is a little awkward because we need to read to and write
+ // from both ends of the Conn, and we need to do it in separate
+ // goroutines because otherwise a Write may block waiting for
+ // someone to Read it. Here we set up a loop in a separate
+ // goroutine, reading from the Conn s and writing to the dataCh
+ // and errCh channels, whose ultimate effect in the select loop
+ // below is like
+ // data, err := ioutil.ReadAll(s)
+ dataCh := make(chan []byte)
+ errCh := make(chan error)
+ go func() {
+ for {
+ var buf [1024]byte
+ n, err := s.Read(buf[:])
+ if err != nil {
+ errCh <- err
+ return
+ }
+ p := make([]byte, n)
+ copy(p, buf[:])
+ dataCh <- p
+ }
+ }()
+
+ // Write the data to the client side of the Conn, one chunk at a
+ // time.
+ for i, chunk := range test {
+ n, err := c.Write(chunk)
+ if err != nil || n != len(chunk) {
+ t.Fatalf("%+q Write chunk %d: got (%d, %v), expected (%d, %v)",
+ test, i, n, err, len(chunk), nil)
+ }
+ }
+ // We cannot immediately c.Close here, because that closes the
+ // connection right away, without waiting for buffered data to
+ // be sent.
+
+ // Pull data and err from the server goroutine above.
+ var data []byte
+ err = nil
+ loop:
+ for {
+ select {
+ case p := <-dataCh:
+ data = append(data, p...)
+ case err = <-errCh:
+ break loop
+ case <-time.After(100 * time.Millisecond):
+ break loop
+ }
+ }
+ s.Close()
+ c.Close()
+
+ // Now data and err contain the result of reading everything
+ // from s.
+ expected := bytes.Join(test, []byte{})
+ if err != nil || !bytes.Equal(data, expected) {
+ t.Fatalf("%+q ReadAll: got (%+q, %v), expected (%+q, %v)",
+ test, data, err, expected, nil)
+ }
+ }
+}
+
+// Test that multiple goroutines may call Read on a Conn simultaneously. Run
+// this with
+// go test -race
+func TestConcurrentRead(t *testing.T) {
+ s, c, err := connPair()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer s.Close()
+
+ // Set up multiple threads reading from the same conn.
+ errCh := make(chan error, 2)
+ var wg sync.WaitGroup
+ wg.Add(2)
+ for i := 0; i < 2; i++ {
+ go func() {
+ defer wg.Done()
+ _, err := io.Copy(ioutil.Discard, s)
+ if err != nil {
+ errCh <- err
+ }
+ }()
+ }
+
+ // Write a bunch of data to the other end.
+ for i := 0; i < 2000; i++ {
+ _, err := c.Write([]byte(fmt.Sprintf("%d", i)))
+ if err != nil {
+ c.Close()
+ t.Fatalf("Write: %v", err)
+ }
+ }
+ c.Close()
+
+ wg.Wait()
+ close(errCh)
+
+ err = <-errCh
+ if err != nil {
+ t.Fatalf("Read: %v", err)
+ }
+}
+
+// Test that multiple goroutines may call Write on a Conn simultaneously. Run
+// this with
+// go test -race
+func TestConcurrentWrite(t *testing.T) {
+ s, c, err := connPair()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Set up multiple threads writing to the same conn.
+ errCh := make(chan error, 3)
+ var wg sync.WaitGroup
+ wg.Add(2)
+ for i := 0; i < 2; i++ {
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 1000; j++ {
+ _, err := fmt.Fprintf(s, "%d", j)
+ if err != nil {
+ errCh <- err
+ break
+ }
+ }
+ }()
+ }
+ go func() {
+ wg.Wait()
+ err := s.Close()
+ if err != nil {
+ errCh <- err
+ }
+ close(errCh)
+ }()
+
+ // Read from the other end.
+ _, err = io.Copy(ioutil.Discard, c)
+ c.Close()
+ if err != nil {
+ t.Fatalf("Read: %v", err)
+ }
+
+ err = <-errCh
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+}