Browse Source

cmd/faucet: fix websocket race regression after switching to gorilla

Péter Szilágyi 4 years ago
parent
commit
44208d9258
1 changed files with 35 additions and 24 deletions
  1. 35 24
      cmd/faucet/faucet.go

+ 35 - 24
cmd/faucet/faucet.go

@@ -213,7 +213,7 @@ type faucet struct {
 	nonce    uint64             // Current pending nonce of the faucet
 	price    *big.Int           // Current gas price to issue funds with
 
-	conns    []*websocket.Conn    // Currently live websocket connections
+	conns    []*wsConn            // Currently live websocket connections
 	timeouts map[string]time.Time // History of users and their funding timeouts
 	reqs     []*request           // Currently pending funding requests
 	update   chan struct{}        // Channel to signal request updates
@@ -221,6 +221,13 @@ type faucet struct {
 	lock sync.RWMutex // Lock protecting the faucet's internals
 }
 
+// wsConn wraps a websocket connection with a write mutex as the underlying
+// websocket library does not synchronize access to the stream.
+type wsConn struct {
+	conn  *websocket.Conn
+	wlock sync.Mutex
+}
+
 func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network uint64, stats string, ks *keystore.KeyStore, index []byte) (*faucet, error) {
 	// Assemble the raw devp2p protocol stack
 	stack, err := node.New(&node.Config{
@@ -321,13 +328,14 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 	defer conn.Close()
 
 	f.lock.Lock()
-	f.conns = append(f.conns, conn)
+	wsconn := &wsConn{conn: conn}
+	f.conns = append(f.conns, wsconn)
 	f.lock.Unlock()
 
 	defer func() {
 		f.lock.Lock()
 		for i, c := range f.conns {
-			if c == conn {
+			if c.conn == conn {
 				f.conns = append(f.conns[:i], f.conns[i+1:]...)
 				break
 			}
@@ -355,7 +363,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 		if head == nil || balance == nil {
 			// Report the faucet offline until initial stats are ready
 			//lint:ignore ST1005 This error is to be displayed in the browser
-			if err = sendError(conn, errors.New("Faucet offline")); err != nil {
+			if err = sendError(wsconn, errors.New("Faucet offline")); err != nil {
 				log.Warn("Failed to send faucet error to client", "err", err)
 				return
 			}
@@ -366,7 +374,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 	f.lock.RLock()
 	reqs := f.reqs
 	f.lock.RUnlock()
-	if err = send(conn, map[string]interface{}{
+	if err = send(wsconn, map[string]interface{}{
 		"funds":    new(big.Int).Div(balance, ether),
 		"funded":   nonce,
 		"peers":    f.stack.Server().PeerCount(),
@@ -375,7 +383,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 		log.Warn("Failed to send initial stats to client", "err", err)
 		return
 	}
-	if err = send(conn, head, 3*time.Second); err != nil {
+	if err = send(wsconn, head, 3*time.Second); err != nil {
 		log.Warn("Failed to send initial header to client", "err", err)
 		return
 	}
@@ -391,7 +399,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 		if !*noauthFlag && !strings.HasPrefix(msg.URL, "https://twitter.com/") && !strings.HasPrefix(msg.URL, "https://www.facebook.com/") {
-			if err = sendError(conn, errors.New("URL doesn't link to supported services")); err != nil {
+			if err = sendError(wsconn, errors.New("URL doesn't link to supported services")); err != nil {
 				log.Warn("Failed to send URL error to client", "err", err)
 				return
 			}
@@ -399,7 +407,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 		}
 		if msg.Tier >= uint(*tiersFlag) {
 			//lint:ignore ST1005 This error is to be displayed in the browser
-			if err = sendError(conn, errors.New("Invalid funding tier requested")); err != nil {
+			if err = sendError(wsconn, errors.New("Invalid funding tier requested")); err != nil {
 				log.Warn("Failed to send tier error to client", "err", err)
 				return
 			}
@@ -415,7 +423,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 
 			res, err := http.PostForm("https://www.google.com/recaptcha/api/siteverify", form)
 			if err != nil {
-				if err = sendError(conn, err); err != nil {
+				if err = sendError(wsconn, err); err != nil {
 					log.Warn("Failed to send captcha post error to client", "err", err)
 					return
 				}
@@ -428,7 +436,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 			err = json.NewDecoder(res.Body).Decode(&result)
 			res.Body.Close()
 			if err != nil {
-				if err = sendError(conn, err); err != nil {
+				if err = sendError(wsconn, err); err != nil {
 					log.Warn("Failed to send captcha decode error to client", "err", err)
 					return
 				}
@@ -437,7 +445,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 			if !result.Success {
 				log.Warn("Captcha verification failed", "err", string(result.Errors))
 				//lint:ignore ST1005 it's funny and the robot won't mind
-				if err = sendError(conn, errors.New("Beep-bop, you're a robot!")); err != nil {
+				if err = sendError(wsconn, errors.New("Beep-bop, you're a robot!")); err != nil {
 					log.Warn("Failed to send captcha failure to client", "err", err)
 					return
 				}
@@ -465,7 +473,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 			err = errors.New("Something funky happened, please open an issue at https://github.com/ethereum/go-ethereum/issues")
 		}
 		if err != nil {
-			if err = sendError(conn, err); err != nil {
+			if err = sendError(wsconn, err); err != nil {
 				log.Warn("Failed to send prefix error to client", "err", err)
 				return
 			}
@@ -489,7 +497,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 			signed, err := f.keystore.SignTx(f.account, tx, f.config.ChainID)
 			if err != nil {
 				f.lock.Unlock()
-				if err = sendError(conn, err); err != nil {
+				if err = sendError(wsconn, err); err != nil {
 					log.Warn("Failed to send transaction creation error to client", "err", err)
 					return
 				}
@@ -498,7 +506,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 			// Submit the transaction and mark as funded if successful
 			if err := f.client.SendTransaction(context.Background(), signed); err != nil {
 				f.lock.Unlock()
-				if err = sendError(conn, err); err != nil {
+				if err = sendError(wsconn, err); err != nil {
 					log.Warn("Failed to send transaction transmission error to client", "err", err)
 					return
 				}
@@ -520,13 +528,13 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) {
 
 		// Send an error if too frequent funding, othewise a success
 		if !fund {
-			if err = sendError(conn, fmt.Errorf("%s left until next allowance", common.PrettyDuration(time.Until(timeout)))); err != nil { // nolint: gosimple
+			if err = sendError(wsconn, fmt.Errorf("%s left until next allowance", common.PrettyDuration(time.Until(timeout)))); err != nil { // nolint: gosimple
 				log.Warn("Failed to send funding error to client", "err", err)
 				return
 			}
 			continue
 		}
-		if err = sendSuccess(conn, fmt.Sprintf("Funding request accepted for %s into %s", username, address.Hex())); err != nil {
+		if err = sendSuccess(wsconn, fmt.Sprintf("Funding request accepted for %s into %s", username, address.Hex())); err != nil {
 			log.Warn("Failed to send funding success to client", "err", err)
 			return
 		}
@@ -619,12 +627,12 @@ func (f *faucet) loop() {
 					"requests": f.reqs,
 				}, time.Second); err != nil {
 					log.Warn("Failed to send stats to client", "err", err)
-					conn.Close()
+					conn.conn.Close()
 					continue
 				}
 				if err := send(conn, head, time.Second); err != nil {
 					log.Warn("Failed to send header to client", "err", err)
-					conn.Close()
+					conn.conn.Close()
 				}
 			}
 			f.lock.RUnlock()
@@ -646,7 +654,7 @@ func (f *faucet) loop() {
 			for _, conn := range f.conns {
 				if err := send(conn, map[string]interface{}{"requests": f.reqs}, time.Second); err != nil {
 					log.Warn("Failed to send requests to client", "err", err)
-					conn.Close()
+					conn.conn.Close()
 				}
 			}
 			f.lock.RUnlock()
@@ -656,23 +664,26 @@ func (f *faucet) loop() {
 
 // sends transmits a data packet to the remote end of the websocket, but also
 // setting a write deadline to prevent waiting forever on the node.
-func send(conn *websocket.Conn, value interface{}, timeout time.Duration) error {
+func send(conn *wsConn, value interface{}, timeout time.Duration) error {
 	if timeout == 0 {
 		timeout = 60 * time.Second
 	}
-	conn.SetWriteDeadline(time.Now().Add(timeout))
-	return conn.WriteJSON(value)
+	conn.wlock.Lock()
+	defer conn.wlock.Unlock()
+
+	conn.conn.SetWriteDeadline(time.Now().Add(timeout))
+	return conn.conn.WriteJSON(value)
 }
 
 // sendError transmits an error to the remote end of the websocket, also setting
 // the write deadline to 1 second to prevent waiting forever.
-func sendError(conn *websocket.Conn, err error) error {
+func sendError(conn *wsConn, err error) error {
 	return send(conn, map[string]string{"error": err.Error()}, time.Second)
 }
 
 // sendSuccess transmits a success message to the remote end of the websocket, also
 // setting the write deadline to 1 second to prevent waiting forever.
-func sendSuccess(conn *websocket.Conn, msg string) error {
+func sendSuccess(conn *wsConn, msg string) error {
 	return send(conn, map[string]string{"success": msg}, time.Second)
 }