Browse Source

whisper: track active peers, add peer cache expiry test

Péter Szilágyi 10 years ago
parent
commit
e5e91e9eb3
2 changed files with 86 additions and 22 deletions
  1. 50 0
      whisper/peer_test.go
  2. 36 22
      whisper/whisper.go

+ 50 - 0
whisper/peer_test.go

@@ -190,3 +190,53 @@ func TestPeerDeliver(t *testing.T) {
 		t.Fatalf("repeating message arrived")
 	}
 }
+
+func TestPeerMessageExpiration(t *testing.T) {
+	// Start a tester and execute the handshake
+	tester, err := startTestPeerInited()
+	if err != nil {
+		t.Fatalf("failed to start initialized peer: %v", err)
+	}
+	defer tester.stream.Close()
+
+	// Fetch the peer instance for later inspection
+	tester.client.peerMu.RLock()
+	if peers := len(tester.client.peers); peers != 1 {
+		t.Fatalf("peer pool size mismatch: have %v, want %v", peers, 1)
+	}
+	var peer *peer
+	for peer, _ = range tester.client.peers {
+		break
+	}
+	tester.client.peerMu.RUnlock()
+
+	// Construct a message and pass it through the tester
+	message := NewMessage([]byte("peer test message"))
+	envelope, err := message.Wrap(DefaultPoW, Options{
+		TTL: time.Second,
+	})
+	if err != nil {
+		t.Fatalf("failed to wrap message: %v", err)
+	}
+	if err := tester.client.Send(envelope); err != nil {
+		t.Fatalf("failed to send message: %v", err)
+	}
+	payload := []interface{}{envelope}
+	if err := p2p.ExpectMsg(tester.stream, messagesCode, payload); err != nil {
+		t.Fatalf("message mismatch: %v", err)
+	}
+	// Check that the message is inside the cache
+	if !peer.known.Has(envelope.Hash()) {
+		t.Fatalf("message not found in cache")
+	}
+	// Discard messages until expiration and check cache again
+	exp := time.Now().Add(time.Second + expirationCycle)
+	for time.Now().Before(exp) {
+		if err := p2p.ExpectMsg(tester.stream, messagesCode, []interface{}{}); err != nil {
+			t.Fatalf("message mismatch: %v", err)
+		}
+	}
+	if peer.known.Has(envelope.Hash()) {
+		t.Fatalf("message not expired from cache")
+	}
+}

+ 36 - 22
whisper/whisper.go

@@ -46,22 +46,26 @@ type Whisper struct {
 	protocol p2p.Protocol
 	filters  *filter.Filters
 
-	mmu      sync.RWMutex              // Message mutex to sync the below pool
-	messages map[common.Hash]*Envelope // Pool of messages currently tracked by this node
-	expiry   map[uint32]*set.SetNonTS  // Message expiration pool (TODO: something lighter)
+	keys map[string]*ecdsa.PrivateKey
 
-	quit chan struct{}
+	messages    map[common.Hash]*Envelope // Pool of messages currently tracked by this node
+	expirations map[uint32]*set.SetNonTS  // Message expiration pool (TODO: something lighter)
+	poolMu      sync.RWMutex              // Mutex to sync the message and expiration pools
 
-	keys map[string]*ecdsa.PrivateKey
+	peers  map[*peer]struct{} // Set of currently active peers
+	peerMu sync.RWMutex       // Mutex to sync the active peer set
+
+	quit chan struct{}
 }
 
 func New() *Whisper {
 	whisper := &Whisper{
-		messages: make(map[common.Hash]*Envelope),
-		filters:  filter.New(),
-		expiry:   make(map[uint32]*set.SetNonTS),
-		quit:     make(chan struct{}),
-		keys:     make(map[string]*ecdsa.PrivateKey),
+		filters:     filter.New(),
+		keys:        make(map[string]*ecdsa.PrivateKey),
+		messages:    make(map[common.Hash]*Envelope),
+		expirations: make(map[uint32]*set.SetNonTS),
+		peers:       make(map[*peer]struct{}),
+		quit:        make(chan struct{}),
 	}
 	whisper.filters.Start()
 
@@ -179,6 +183,16 @@ func (self *Whisper) handlePeer(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
 	whisperPeer.start()
 	defer whisperPeer.stop()
 
+	// Start tracking the active peer
+	self.peerMu.Lock()
+	self.peers[whisperPeer] = struct{}{}
+	self.peerMu.Unlock()
+
+	defer func() {
+		self.peerMu.Lock()
+		delete(self.peers, whisperPeer)
+		self.peerMu.Unlock()
+	}()
 	// Read and process inbound messages directly to merge into client-global state
 	for {
 		// Fetch the next packet and decode the contained envelopes
@@ -206,8 +220,8 @@ func (self *Whisper) handlePeer(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
 // whisper network. It also inserts the envelope into the expiration pool at the
 // appropriate time-stamp.
 func (self *Whisper) add(envelope *Envelope) error {
-	self.mmu.Lock()
-	defer self.mmu.Unlock()
+	self.poolMu.Lock()
+	defer self.poolMu.Unlock()
 
 	// Insert the message into the tracked pool
 	hash := envelope.Hash()
@@ -218,11 +232,11 @@ func (self *Whisper) add(envelope *Envelope) error {
 	self.messages[hash] = envelope
 
 	// Insert the message into the expiration pool for later removal
-	if self.expiry[envelope.Expiry] == nil {
-		self.expiry[envelope.Expiry] = set.NewNonTS()
+	if self.expirations[envelope.Expiry] == nil {
+		self.expirations[envelope.Expiry] = set.NewNonTS()
 	}
-	if !self.expiry[envelope.Expiry].Has(hash) {
-		self.expiry[envelope.Expiry].Add(hash)
+	if !self.expirations[envelope.Expiry].Has(hash) {
+		self.expirations[envelope.Expiry].Add(hash)
 
 		// Notify the local node of a message arrival
 		go self.postEvent(envelope)
@@ -292,11 +306,11 @@ func (self *Whisper) update() {
 // expire iterates over all the expiration timestamps, removing all stale
 // messages from the pools.
 func (self *Whisper) expire() {
-	self.mmu.Lock()
-	defer self.mmu.Unlock()
+	self.poolMu.Lock()
+	defer self.poolMu.Unlock()
 
 	now := uint32(time.Now().Unix())
-	for then, hashSet := range self.expiry {
+	for then, hashSet := range self.expirations {
 		// Short circuit if a future time
 		if then > now {
 			continue
@@ -306,14 +320,14 @@ func (self *Whisper) expire() {
 			delete(self.messages, v.(common.Hash))
 			return true
 		})
-		self.expiry[then].Clear()
+		self.expirations[then].Clear()
 	}
 }
 
 // envelopes retrieves all the messages currently pooled by the node.
 func (self *Whisper) envelopes() []*Envelope {
-	self.mmu.RLock()
-	defer self.mmu.RUnlock()
+	self.poolMu.RLock()
+	defer self.poolMu.RUnlock()
 
 	envelopes := make([]*Envelope, 0, len(self.messages))
 	for _, envelope := range self.messages {