Parcourir la source

p2p: track write errors and prevent writes during shutdown

As of this commit, we no longer rely on the protocol handler to report
write errors in a timely fashion. When a write fails, shutdown is
initiated immediately and no new writes can start. This will also
prevent new writes from starting after Server.Stop has been called.
Felix Lange il y a 10 ans
Parent
commit
8dcbdcad0a
1 fichiers modifiés avec 57 ajouts et 25 suppressions
  1. 57 25
      p2p/peer.go

+ 57 - 25
p2p/peer.go

@@ -115,37 +115,54 @@ func newPeer(conn *conn, protocols []Protocol) *Peer {
 }
 
 func (p *Peer) run() DiscReason {
-	readErr := make(chan error, 1)
+	var (
+		writeStart = make(chan struct{}, 1)
+		writeErr   = make(chan error, 1)
+		readErr    = make(chan error, 1)
+		reason     DiscReason
+		requested  bool
+	)
 	p.wg.Add(2)
 	go p.readLoop(readErr)
 	go p.pingLoop()
 
-	p.startProtocols()
+	// Start all protocol handlers.
+	writeStart <- struct{}{}
+	p.startProtocols(writeStart, writeErr)
 
 	// Wait for an error or disconnect.
-	var (
-		reason    DiscReason
-		requested bool
-	)
-	select {
-	case err := <-readErr:
-		if r, ok := err.(DiscReason); ok {
-			reason = r
-		} else {
-			// Note: We rely on protocols to abort if there is a write
-			// error. It might be more robust to handle them here as well.
-			glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
-			reason = DiscNetworkError
+loop:
+	for {
+		select {
+		case err := <-writeErr:
+			// A write finished. Allow the next write to start if
+			// there was no error.
+			if err != nil {
+				glog.V(logger.Detail).Infof("%v: Write error: %v\n", p, err)
+				reason = DiscNetworkError
+				break loop
+			}
+			writeStart <- struct{}{}
+		case err := <-readErr:
+			if r, ok := err.(DiscReason); ok {
+				reason = r
+			} else {
+				glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
+				reason = DiscNetworkError
+			}
+			break loop
+		case err := <-p.protoErr:
+			reason = discReasonForError(err)
+			break loop
+		case reason = <-p.disc:
+			requested = true
+			break loop
 		}
-	case err := <-p.protoErr:
-		reason = discReasonForError(err)
-	case reason = <-p.disc:
-		requested = true
 	}
+
 	close(p.closed)
 	p.rw.close(reason)
 	p.wg.Wait()
-
 	if requested {
 		reason = DiscRequested
 	}
@@ -247,11 +264,13 @@ outer:
 	return result
 }
 
-func (p *Peer) startProtocols() {
+func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
 	p.wg.Add(len(p.running))
 	for _, proto := range p.running {
 		proto := proto
 		proto.closed = p.closed
+		proto.wstart = writeStart
+		proto.werr = writeErr
 		glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
 		go func() {
 			err := proto.Run(p, proto)
@@ -280,18 +299,31 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
 
 type protoRW struct {
 	Protocol
-	in     chan Msg
-	closed <-chan struct{}
+	in     chan Msg        // receices read messages
+	closed <-chan struct{} // receives when peer is shutting down
+	wstart <-chan struct{} // receives when write may start
+	werr   chan<- error    // for write results
 	offset uint64
 	w      MsgWriter
 }
 
-func (rw *protoRW) WriteMsg(msg Msg) error {
+func (rw *protoRW) WriteMsg(msg Msg) (err error) {
 	if msg.Code >= rw.Length {
 		return newPeerError(errInvalidMsgCode, "not handled")
 	}
 	msg.Code += rw.offset
-	return rw.w.WriteMsg(msg)
+	select {
+	case <-rw.wstart:
+		err = rw.w.WriteMsg(msg)
+		// Report write status back to Peer.run. It will initiate
+		// shutdown if the error is non-nil and unblock the next write
+		// otherwise. The calling protocol code should exit for errors
+		// as well but we don't want to rely on that.
+		rw.werr <- err
+	case <-rw.closed:
+		err = fmt.Errorf("shutting down")
+	}
+	return err
 }
 
 func (rw *protoRW) ReadMsg() (Msg, error) {