Browse Source

ethstats: avoid concurrent write on websocket (#21404)

Fixes #21403
Martin Holst Swende 5 years ago
parent
commit
82a9e11058
1 changed files with 57 additions and 11 deletions
  1. 57 11
      ethstats/ethstats.go

+ 57 - 11
ethstats/ethstats.go

@@ -28,6 +28,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -93,6 +94,49 @@ type Service struct {
 
 	pongCh chan struct{} // Pong notifications are fed into this channel
 	histCh chan []uint64 // History request block numbers are fed into this channel
+
+}
+
+// connWrapper is a wrapper to prevent concurrent-write or concurrent-read on the
+// websocket.
+// From Gorilla websocket docs:
+// Connections support one concurrent reader and one concurrent writer.
+// Applications are responsible for ensuring that no more than one goroutine calls the write methods
+// - NextWriter, SetWriteDeadline, WriteMessage, WriteJSON, EnableWriteCompression, SetCompressionLevel
+// concurrently and that no more than one goroutine calls the read methods
+// - NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler
+// concurrently.
+// The Close and WriteControl methods can be called concurrently with all other methods.
+//
+// The connWrapper uses a single mutex for both reading and writing.
+type connWrapper struct {
+	conn *websocket.Conn
+	mu   sync.Mutex
+}
+
+func newConnectionWrapper(conn *websocket.Conn) *connWrapper {
+	return &connWrapper{conn: conn}
+}
+
+// WriteJSON wraps corresponding method on the websocket but is safe for concurrent calling
+func (w *connWrapper) WriteJSON(v interface{}) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	return w.conn.WriteJSON(v)
+}
+
+// ReadJSON wraps corresponding method on the websocket but is safe for concurrent calling
+func (w *connWrapper) ReadJSON(v interface{}) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	return w.conn.ReadJSON(v)
+}
+
+// Close wraps corresponding method on the websocket but is safe for concurrent calling
+func (w *connWrapper) Close() error {
+	// The Close and WriteControl methods can be called concurrently with all other methods,
+	// so the mutex is not used here
+	return w.conn.Close()
 }
 
 // New returns a monitoring service ready for stats reporting.
@@ -204,17 +248,19 @@ func (s *Service) loop() {
 		case <-errTimer.C:
 			// Establish a websocket connection to the server on any supported URL
 			var (
-				conn *websocket.Conn
+				conn *connWrapper
 				err  error
 			)
 			dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second}
 			header := make(http.Header)
 			header.Set("origin", "http://localhost")
 			for _, url := range urls {
-				conn, _, err = dialer.Dial(url, header)
-				if err == nil {
+				c, _, e := dialer.Dial(url, header)
+				if e == nil {
+					conn = newConnectionWrapper(c)
 					break
 				}
+				err = e
 			}
 			if err != nil {
 				log.Warn("Stats server unreachable", "err", err)
@@ -282,7 +328,7 @@ func (s *Service) loop() {
 // from the network socket. If any of them match an active request, it forwards
 // it, if they themselves are requests it initiates a reply, and lastly it drops
 // unknown packets.
-func (s *Service) readLoop(conn *websocket.Conn) {
+func (s *Service) readLoop(conn *connWrapper) {
 	// If the read loop exists, close the connection
 	defer conn.Close()
 
@@ -391,7 +437,7 @@ type authMsg struct {
 }
 
 // login tries to authorize the client at the remote server.
-func (s *Service) login(conn *websocket.Conn) error {
+func (s *Service) login(conn *connWrapper) error {
 	// Construct and send the login authentication
 	infos := s.server.NodeInfo()
 
@@ -436,7 +482,7 @@ func (s *Service) login(conn *websocket.Conn) error {
 // report collects all possible data to report and send it to the stats server.
 // This should only be used on reconnects or rarely to avoid overloading the
 // server. Use the individual methods for reporting subscribed events.
-func (s *Service) report(conn *websocket.Conn) error {
+func (s *Service) report(conn *connWrapper) error {
 	if err := s.reportLatency(conn); err != nil {
 		return err
 	}
@@ -454,7 +500,7 @@ func (s *Service) report(conn *websocket.Conn) error {
 
 // reportLatency sends a ping request to the server, measures the RTT time and
 // finally sends a latency update.
-func (s *Service) reportLatency(conn *websocket.Conn) error {
+func (s *Service) reportLatency(conn *connWrapper) error {
 	// Send the current time to the ethstats server
 	start := time.Now()
 
@@ -523,7 +569,7 @@ func (s uncleStats) MarshalJSON() ([]byte, error) {
 }
 
 // reportBlock retrieves the current chain head and reports it to the stats server.
-func (s *Service) reportBlock(conn *websocket.Conn, block *types.Block) error {
+func (s *Service) reportBlock(conn *connWrapper, block *types.Block) error {
 	// Gather the block details from the header or block chain
 	details := s.assembleBlockStats(block)
 
@@ -598,7 +644,7 @@ func (s *Service) assembleBlockStats(block *types.Block) *blockStats {
 
 // reportHistory retrieves the most recent batch of blocks and reports it to the
 // stats server.
-func (s *Service) reportHistory(conn *websocket.Conn, list []uint64) error {
+func (s *Service) reportHistory(conn *connWrapper, list []uint64) error {
 	// Figure out the indexes that need reporting
 	indexes := make([]uint64, 0, historyUpdateRange)
 	if len(list) > 0 {
@@ -660,7 +706,7 @@ type pendStats struct {
 
 // reportPending retrieves the current number of pending transactions and reports
 // it to the stats server.
-func (s *Service) reportPending(conn *websocket.Conn) error {
+func (s *Service) reportPending(conn *connWrapper) error {
 	// Retrieve the pending count from the local blockchain
 	pending, _ := s.backend.Stats()
 	// Assemble the transaction stats and send it to the server
@@ -691,7 +737,7 @@ type nodeStats struct {
 
 // reportStats retrieves various stats about the node at the networking and
 // mining layer and reports it to the stats server.
-func (s *Service) reportStats(conn *websocket.Conn) error {
+func (s *Service) reportStats(conn *connWrapper) error {
 	// Gather the syncing and mining infos from the local miner instance
 	var (
 		mining   bool