Browse Source

p2p: improve and test eofSignal

Felix Lange 11 years ago
parent
commit
e28c60caf9
2 changed files with 68 additions and 5 deletions
  1. 12 5
      p2p/peer.go
  2. 56 0
      p2p/peer_test.go

+ 12 - 5
p2p/peer.go

@@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error)
 		proto.in <- msg
 	} else {
 		wait = true
-		pr := &eofSignal{msg.Payload, protoDone}
+		pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
 		msg.Payload = pr
 		proto.in <- msg
 	}
@@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) {
 	return msg, nil
 }
 
-// eofSignal wraps a reader with eof signaling.
-// the eof channel is closed when the wrapped reader
-// reaches EOF.
+// eofSignal wraps a reader with eof signaling. the eof channel is
+// closed when the wrapped reader returns an error or when count bytes
+// have been read.
+//
 type eofSignal struct {
 	wrapped io.Reader
+	count   int64
 	eof     chan<- struct{}
 }
 
+// note: when using eofSignal to detect whether a message payload
+// has been read, Read might not be called for zero sized messages.
+
 func (r *eofSignal) Read(buf []byte) (int, error) {
 	n, err := r.wrapped.Read(buf)
-	if err != nil {
+	r.count -= int64(n)
+	if (err != nil || r.count <= 0) && r.eof != nil {
 		r.eof <- struct{}{} // tell Peer that msg has been consumed
+		r.eof = nil
 	}
 	return n, err
 }

+ 56 - 0
p2p/peer_test.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/hex"
+	"io"
 	"io/ioutil"
 	"net"
 	"reflect"
@@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) {
 	// Should not hang.
 	p.Disconnect(DiscAlreadyConnected)
 }
+
+func TestEOFSignal(t *testing.T) {
+	rb := make([]byte, 10)
+
+	// empty reader
+	eof := make(chan struct{}, 1)
+	sig := &eofSignal{new(bytes.Buffer), 0, eof}
+	if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+	}
+	select {
+	case <-eof:
+	default:
+		t.Error("EOF chan not signaled")
+	}
+
+	// count before error
+	eof = make(chan struct{}, 1)
+	sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
+	if n, err := sig.Read(rb); n != 8 || err != nil {
+		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+	}
+	select {
+	case <-eof:
+	default:
+		t.Error("EOF chan not signaled")
+	}
+
+	// error before count
+	eof = make(chan struct{}, 1)
+	sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
+	if n, err := sig.Read(rb); n != 4 || err != nil {
+		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+	}
+	if n, err := sig.Read(rb); n != 0 || err != io.EOF {
+		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+	}
+	select {
+	case <-eof:
+	default:
+		t.Error("EOF chan not signaled")
+	}
+
+	// no signal if neither occurs
+	eof = make(chan struct{}, 1)
+	sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
+	if n, err := sig.Read(rb); n != 10 || err != nil {
+		t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
+	}
+	select {
+	case <-eof:
+		t.Error("unexpected EOF signal")
+	default:
+	}
+}