Browse Source

p2p: track write errors and prevent writes during shutdown

As of this commit, we no longer rely on the protocol handler to report
write errors in a timely fashion. When a write fails, shutdown is
initiated immediately and no new writes can start. This will also
prevent new writes from starting after Server.Stop has been called.
Felix Lange 10 years ago
parent
commit
8dcbdcad0a
1 changed files with 57 additions and 25 deletions
  1. 57 25
      p2p/peer.go

+ 57 - 25
p2p/peer.go

@@ -115,37 +115,54 @@ func newPeer(conn *conn, protocols []Protocol) *Peer {
 }
 
 func (p *Peer) run() DiscReason {
-	readErr := make(chan error, 1)
+	var (
+		writeStart = make(chan struct{}, 1)
+		writeErr   = make(chan error, 1)
+		readErr    = make(chan error, 1)
+		reason     DiscReason
+		requested  bool
+	)
 	p.wg.Add(2)
 	go p.readLoop(readErr)
 	go p.pingLoop()
 
-	p.startProtocols()
+	// Start all protocol handlers.
+	writeStart <- struct{}{}
+	p.startProtocols(writeStart, writeErr)
 
 	// Wait for an error or disconnect.
-	var (
-		reason    DiscReason
-		requested bool
-	)
-	select {
-	case err := <-readErr:
-		if r, ok := err.(DiscReason); ok {
-			reason = r
-		} else {
-			// Note: We rely on protocols to abort if there is a write
-			// error. It might be more robust to handle them here as well.
-			glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
-			reason = DiscNetworkError
+loop:
+	for {
+		select {
+		case err := <-writeErr:
+			// A write finished. Allow the next write to start if
+			// there was no error.
+			if err != nil {
+				glog.V(logger.Detail).Infof("%v: Write error: %v\n", p, err)
+				reason = DiscNetworkError
+				break loop
+			}
+			writeStart <- struct{}{}
+		case err := <-readErr:
+			if r, ok := err.(DiscReason); ok {
+				reason = r
+			} else {
+				glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
+				reason = DiscNetworkError
+			}
+			break loop
+		case err := <-p.protoErr:
+			reason = discReasonForError(err)
+			break loop
+		case reason = <-p.disc:
+			requested = true
+			break loop
 		}
-	case err := <-p.protoErr:
-		reason = discReasonForError(err)
-	case reason = <-p.disc:
-		requested = true
 	}
+
 	close(p.closed)
 	p.rw.close(reason)
 	p.wg.Wait()
-
 	if requested {
 		reason = DiscRequested
 	}
@@ -247,11 +264,13 @@ outer:
 	return result
 }
 
-func (p *Peer) startProtocols() {
+func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
 	p.wg.Add(len(p.running))
 	for _, proto := range p.running {
 		proto := proto
 		proto.closed = p.closed
+		proto.wstart = writeStart
+		proto.werr = writeErr
 		glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
 		go func() {
 			err := proto.Run(p, proto)
@@ -280,18 +299,31 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
 
 type protoRW struct {
 	Protocol
-	in     chan Msg
-	closed <-chan struct{}
+	in     chan Msg        // receices read messages
+	closed <-chan struct{} // receives when peer is shutting down
+	wstart <-chan struct{} // receives when write may start
+	werr   chan<- error    // for write results
 	offset uint64
 	w      MsgWriter
 }
 
-func (rw *protoRW) WriteMsg(msg Msg) error {
+func (rw *protoRW) WriteMsg(msg Msg) (err error) {
 	if msg.Code >= rw.Length {
 		return newPeerError(errInvalidMsgCode, "not handled")
 	}
 	msg.Code += rw.offset
-	return rw.w.WriteMsg(msg)
+	select {
+	case <-rw.wstart:
+		err = rw.w.WriteMsg(msg)
+		// Report write status back to Peer.run. It will initiate
+		// shutdown if the error is non-nil and unblock the next write
+		// otherwise. The calling protocol code should exit for errors
+		// as well but we don't want to rely on that.
+		rw.werr <- err
+	case <-rw.closed:
+		err = fmt.Errorf("shutting down")
+	}
+	return err
 }
 
 func (rw *protoRW) ReadMsg() (Msg, error) {