瀏覽代碼

core, eth, les: fix messy code (#15367)

* core, eth, les: fix messy code

* les: fixed tx status test and rlp encoding

* core: add a workaround for light sync
Péter Szilágyi 8 年之前
父節點
當前提交
0095531a58
共有 11 個文件被更改,包括 175 次插入203 次删除
  1. 24 43
      core/bloombits/matcher.go
  2. 25 24
      core/chain_indexer.go
  3. 1 1
      core/chain_indexer_test.go
  4. 5 5
      core/tx_list.go
  5. 44 75
      core/tx_pool.go
  6. 2 3
      eth/helper_test.go
  7. 1 1
      eth/protocol.go
  8. 35 30
      les/handler.go
  9. 29 18
      les/handler_test.go
  10. 2 3
      les/peer.go
  11. 7 0
      les/protocol.go

+ 24 - 43
core/bloombits/matcher.go

@@ -57,12 +57,16 @@ type partialMatches struct {
 // Retrieval represents a request for retrieval task assignments for a given
 // bit with the given number of fetch elements, or a response for such a request.
 // It can also have the actual results set to be used as a delivery data struct.
+//
+// The contest and error fields are used by the light client to terminate matching
+// early if an error is enountered on some path of the pipeline.
 type Retrieval struct {
 	Bit      uint
 	Sections []uint64
 	Bitsets  [][]byte
-	Error    error
-	Context  context.Context
+
+	Context context.Context
+	Error   error
 }
 
 // Matcher is a pipelined system of schedulers and logic matchers which perform
@@ -506,54 +510,31 @@ func (m *Matcher) distributor(dist chan *request, session *MatcherSession) {
 type MatcherSession struct {
 	matcher *Matcher
 
-	quit     chan struct{} // Quit channel to request pipeline termination
-	kill     chan struct{} // Term channel to signal non-graceful forced shutdown
-	ctx      context.Context
-	err      error
-	stopping bool
-	lock     sync.Mutex
-	pend     sync.WaitGroup
+	closer sync.Once     // Sync object to ensure we only ever close once
+	quit   chan struct{} // Quit channel to request pipeline termination
+	kill   chan struct{} // Term channel to signal non-graceful forced shutdown
+
+	ctx context.Context // Context used by the light client to abort filtering
+	err atomic.Value    // Global error to track retrieval failures deep in the chain
+
+	pend sync.WaitGroup
 }
 
 // Close stops the matching process and waits for all subprocesses to terminate
 // before returning. The timeout may be used for graceful shutdown, allowing the
 // currently running retrievals to complete before this time.
 func (s *MatcherSession) Close() {
-	s.lock.Lock()
-	stopping := s.stopping
-	s.stopping = true
-	s.lock.Unlock()
-	// ensure that we only close the session once
-	if stopping {
-		return
-	}
-
-	// Bail out if the matcher is not running
-	select {
-	case <-s.quit:
-		return
-	default:
-	}
-	// Signal termination and wait for all goroutines to tear down
-	close(s.quit)
-	time.AfterFunc(time.Second, func() { close(s.kill) })
-	s.pend.Wait()
+	s.closer.Do(func() {
+		// Signal termination and wait for all goroutines to tear down
+		close(s.quit)
+		time.AfterFunc(time.Second, func() { close(s.kill) })
+		s.pend.Wait()
+	})
 }
 
-// setError sets an error and stops the session
-func (s *MatcherSession) setError(err error) {
-	s.lock.Lock()
-	s.err = err
-	s.lock.Unlock()
-	s.Close()
-}
-
-// Error returns an error if one has happened during the session
+// Error returns any failure encountered during the matching session.
 func (s *MatcherSession) Error() error {
-	s.lock.Lock()
-	defer s.lock.Unlock()
-
-	return s.err
+	return s.err.Load().(error)
 }
 
 // AllocateRetrieval assigns a bloom bit index to a client process that can either
@@ -655,9 +636,9 @@ func (s *MatcherSession) Multiplex(batch int, wait time.Duration, mux chan chan
 
 			result := <-request
 			if result.Error != nil {
-				s.setError(result.Error)
+				s.err.Store(result.Error)
+				s.Close()
 			}
-
 			s.DeliverSections(result.Bit, result.Sections, result.Bitsets)
 		}
 	}

+ 25 - 24
core/chain_indexer.go

@@ -36,17 +36,25 @@ import (
 type ChainIndexerBackend interface {
 	// Reset initiates the processing of a new chain segment, potentially terminating
 	// any partially completed operations (in case of a reorg).
-	Reset(section uint64, lastSectionHead common.Hash) error
+	Reset(section uint64, prevHead common.Hash) error
 
 	// Process crunches through the next header in the chain segment. The caller
 	// will ensure a sequential order of headers.
 	Process(header *types.Header)
 
-	// Commit finalizes the section metadata and stores it into the database. This
-	// interface will usually be a batch writer.
+	// Commit finalizes the section metadata and stores it into the database.
 	Commit() error
 }
 
+// ChainIndexerChain interface is used for connecting the indexer to a blockchain
+type ChainIndexerChain interface {
+	// CurrentHeader retrieves the latest locally known header.
+	CurrentHeader() *types.Header
+
+	// SubscribeChainEvent subscribes to new head header notifications.
+	SubscribeChainEvent(ch chan<- ChainEvent) event.Subscription
+}
+
 // ChainIndexer does a post-processing job for equally sized sections of the
 // canonical chain (like BlooomBits and CHT structures). A ChainIndexer is
 // connected to the blockchain through the event system by starting a
@@ -114,21 +122,14 @@ func (c *ChainIndexer) AddKnownSectionHead(section uint64, shead common.Hash) {
 	c.setValidSections(section + 1)
 }
 
-// IndexerChain interface is used for connecting the indexer to a blockchain
-type IndexerChain interface {
-	CurrentHeader() *types.Header
-	SubscribeChainEvent(ch chan<- ChainEvent) event.Subscription
-}
-
 // Start creates a goroutine to feed chain head events into the indexer for
 // cascading background processing. Children do not need to be started, they
 // are notified about new events by their parents.
-func (c *ChainIndexer) Start(chain IndexerChain) {
-	ch := make(chan ChainEvent, 10)
-	sub := chain.SubscribeChainEvent(ch)
-	currentHeader := chain.CurrentHeader()
+func (c *ChainIndexer) Start(chain ChainIndexerChain) {
+	events := make(chan ChainEvent, 10)
+	sub := chain.SubscribeChainEvent(events)
 
-	go c.eventLoop(currentHeader, ch, sub)
+	go c.eventLoop(chain.CurrentHeader(), events, sub)
 }
 
 // Close tears down all goroutines belonging to the indexer and returns any error
@@ -149,14 +150,12 @@ func (c *ChainIndexer) Close() error {
 			errs = append(errs, err)
 		}
 	}
-
 	// Close all children
 	for _, child := range c.children {
 		if err := child.Close(); err != nil {
 			errs = append(errs, err)
 		}
 	}
-
 	// Return any failures
 	switch {
 	case len(errs) == 0:
@@ -173,7 +172,7 @@ func (c *ChainIndexer) Close() error {
 // eventLoop is a secondary - optional - event loop of the indexer which is only
 // started for the outermost indexer to push chain head events into a processing
 // queue.
-func (c *ChainIndexer) eventLoop(currentHeader *types.Header, ch chan ChainEvent, sub event.Subscription) {
+func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainEvent, sub event.Subscription) {
 	// Mark the chain indexer as active, requiring an additional teardown
 	atomic.StoreUint32(&c.active, 1)
 
@@ -193,7 +192,7 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, ch chan ChainEvent
 			errc <- nil
 			return
 
-		case ev, ok := <-ch:
+		case ev, ok := <-events:
 			// Received a new event, ensure it's not nil (closing) and update
 			if !ok {
 				errc := <-c.quit
@@ -202,6 +201,8 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, ch chan ChainEvent
 			}
 			header := ev.Block.Header()
 			if header.ParentHash != prevHash {
+				// Reorg to the common ancestor (might not exist in light sync mode, skip reorg then)
+				// TODO(karalabe, zsfelfoldi): This seems a bit brittle, can we detect this case explicitly?
 				if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil {
 					c.newHead(h.Number.Uint64(), true)
 				}
@@ -259,8 +260,8 @@ func (c *ChainIndexer) newHead(head uint64, reorg bool) {
 // down into the processing backend.
 func (c *ChainIndexer) updateLoop() {
 	var (
-		updated   time.Time
-		updateMsg bool
+		updating bool
+		updated  time.Time
 	)
 
 	for {
@@ -277,7 +278,7 @@ func (c *ChainIndexer) updateLoop() {
 				// Periodically print an upgrade log message to the user
 				if time.Since(updated) > 8*time.Second {
 					if c.knownSections > c.storedSections+1 {
-						updateMsg = true
+						updating = true
 						c.log.Info("Upgrading chain index", "percentage", c.storedSections*100/c.knownSections)
 					}
 					updated = time.Now()
@@ -300,8 +301,8 @@ func (c *ChainIndexer) updateLoop() {
 				if err == nil && oldHead == c.SectionHead(section-1) {
 					c.setSectionHead(section, newHead)
 					c.setValidSections(section + 1)
-					if c.storedSections == c.knownSections && updateMsg {
-						updateMsg = false
+					if c.storedSections == c.knownSections && updating {
+						updating = false
 						c.log.Info("Finished upgrading chain index")
 					}
 
@@ -412,7 +413,7 @@ func (c *ChainIndexer) setValidSections(sections uint64) {
 	c.storedSections = sections // needed if new > old
 }
 
-// sectionHead retrieves the last block hash of a processed section from the
+// SectionHead retrieves the last block hash of a processed section from the
 // index database.
 func (c *ChainIndexer) SectionHead(section uint64) common.Hash {
 	var data [8]byte

+ 1 - 1
core/chain_indexer_test.go

@@ -209,7 +209,7 @@ func (b *testChainIndexBackend) reorg(headNum uint64) uint64 {
 	return b.stored * b.indexer.sectionSize
 }
 
-func (b *testChainIndexBackend) Reset(section uint64, lastSectionHead common.Hash) error {
+func (b *testChainIndexBackend) Reset(section uint64, prevHead common.Hash) error {
 	b.section = section
 	b.headerCnt = 0
 	return nil

+ 5 - 5
core/tx_list.go

@@ -384,13 +384,13 @@ func (h *priceHeap) Pop() interface{} {
 // txPricedList is a price-sorted heap to allow operating on transactions pool
 // contents in a price-incrementing way.
 type txPricedList struct {
-	all    *map[common.Hash]txLookupRec // Pointer to the map of all transactions
-	items  *priceHeap                   // Heap of prices of all the stored transactions
-	stales int                          // Number of stale price points to (re-heap trigger)
+	all    *map[common.Hash]*types.Transaction // Pointer to the map of all transactions
+	items  *priceHeap                          // Heap of prices of all the stored transactions
+	stales int                                 // Number of stale price points to (re-heap trigger)
 }
 
 // newTxPricedList creates a new price-sorted transaction heap.
-func newTxPricedList(all *map[common.Hash]txLookupRec) *txPricedList {
+func newTxPricedList(all *map[common.Hash]*types.Transaction) *txPricedList {
 	return &txPricedList{
 		all:   all,
 		items: new(priceHeap),
@@ -416,7 +416,7 @@ func (l *txPricedList) Removed() {
 
 	l.stales, l.items = 0, &reheap
 	for _, tx := range *l.all {
-		*l.items = append(*l.items, tx.tx)
+		*l.items = append(*l.items, tx)
 	}
 	heap.Init(l.items)
 }

+ 44 - 75
core/tx_pool.go

@@ -103,6 +103,16 @@ var (
 	underpricedTxCounter = metrics.NewCounter("txpool/underpriced")
 )
 
+// TxStatus is the current status of a transaction as seen py the pool.
+type TxStatus uint
+
+const (
+	TxStatusUnknown TxStatus = iota
+	TxStatusQueued
+	TxStatusPending
+	TxStatusIncluded
+)
+
 // blockChain provides the state of blockchain and current gas limit to do
 // some pre checks in tx pool and event subscribers.
 type blockChain interface {
@@ -192,22 +202,17 @@ type TxPool struct {
 	locals  *accountSet // Set of local transaction to exepmt from evicion rules
 	journal *txJournal  // Journal of local transaction to back up to disk
 
-	pending map[common.Address]*txList   // All currently processable transactions
-	queue   map[common.Address]*txList   // Queued but non-processable transactions
-	beats   map[common.Address]time.Time // Last heartbeat from each known account
-	all     map[common.Hash]txLookupRec  // All transactions to allow lookups
-	priced  *txPricedList                // All transactions sorted by price
+	pending map[common.Address]*txList         // All currently processable transactions
+	queue   map[common.Address]*txList         // Queued but non-processable transactions
+	beats   map[common.Address]time.Time       // Last heartbeat from each known account
+	all     map[common.Hash]*types.Transaction // All transactions to allow lookups
+	priced  *txPricedList                      // All transactions sorted by price
 
 	wg sync.WaitGroup // for shutdown sync
 
 	homestead bool
 }
 
-type txLookupRec struct {
-	tx      *types.Transaction
-	pending bool
-}
-
 // NewTxPool creates a new transaction pool to gather, sort and filter inbound
 // trnsactions from the network.
 func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain blockChain) *TxPool {
@@ -223,7 +228,7 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block
 		pending:     make(map[common.Address]*txList),
 		queue:       make(map[common.Address]*txList),
 		beats:       make(map[common.Address]time.Time),
-		all:         make(map[common.Hash]txLookupRec),
+		all:         make(map[common.Hash]*types.Transaction),
 		chainHeadCh: make(chan ChainHeadEvent, chainHeadChanSize),
 		gasPrice:    new(big.Int).SetUint64(config.PriceLimit),
 	}
@@ -599,7 +604,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error {
 func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) {
 	// If the transaction is already known, discard it
 	hash := tx.Hash()
-	if _, ok := pool.all[hash]; ok {
+	if pool.all[hash] != nil {
 		log.Trace("Discarding already known transaction", "hash", hash)
 		return false, fmt.Errorf("known transaction: %x", hash)
 	}
@@ -640,7 +645,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) {
 			pool.priced.Removed()
 			pendingReplaceCounter.Inc(1)
 		}
-		pool.all[tx.Hash()] = txLookupRec{tx, false}
+		pool.all[tx.Hash()] = tx
 		pool.priced.Put(tx)
 		pool.journalTx(from, tx)
 
@@ -687,7 +692,7 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er
 		pool.priced.Removed()
 		queuedReplaceCounter.Inc(1)
 	}
-	pool.all[hash] = txLookupRec{tx, false}
+	pool.all[hash] = tx
 	pool.priced.Put(tx)
 	return old != nil, nil
 }
@@ -730,13 +735,10 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T
 
 		pendingReplaceCounter.Inc(1)
 	}
-	if pool.all[hash].tx == nil {
-		// Failsafe to work around direct pending inserts (tests)
-		pool.all[hash] = txLookupRec{tx, true}
+	// Failsafe to work around direct pending inserts (tests)
+	if pool.all[hash] == nil {
+		pool.all[hash] = tx
 		pool.priced.Put(tx)
-	} else {
-		// set pending flag to true
-		pool.all[hash] = txLookupRec{tx, true}
 	}
 	// Set the potentially new pending nonce and notify any subsystems of the new tx
 	pool.beats[addr] = time.Now()
@@ -762,17 +764,15 @@ func (pool *TxPool) AddRemote(tx *types.Transaction) error {
 // AddLocals enqueues a batch of transactions into the pool if they are valid,
 // marking the senders as a local ones in the mean time, ensuring they go around
 // the local pricing constraints.
-func (pool *TxPool) AddLocals(txs []*types.Transaction) error {
-	pool.addTxs(txs, !pool.config.NoLocals)
-	return nil
+func (pool *TxPool) AddLocals(txs []*types.Transaction) []error {
+	return pool.addTxs(txs, !pool.config.NoLocals)
 }
 
 // AddRemotes enqueues a batch of transactions into the pool if they are valid.
 // If the senders are not among the locally tracked ones, full pricing constraints
 // will apply.
-func (pool *TxPool) AddRemotes(txs []*types.Transaction) error {
-	pool.addTxs(txs, false)
-	return nil
+func (pool *TxPool) AddRemotes(txs []*types.Transaction) []error {
+	return pool.addTxs(txs, false)
 }
 
 // addTx enqueues a single transaction into the pool if it is valid.
@@ -806,10 +806,11 @@ func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) []error {
 func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) []error {
 	// Add the batch of transaction, tracking the accepted ones
 	dirty := make(map[common.Address]struct{})
-	txErr := make([]error, len(txs))
+	errs := make([]error, len(txs))
+
 	for i, tx := range txs {
 		var replace bool
-		if replace, txErr[i] = pool.add(tx, local); txErr[i] == nil {
+		if replace, errs[i] = pool.add(tx, local); errs[i] == nil {
 			if !replace {
 				from, _ := types.Sender(pool.signer, tx) // already validated
 				dirty[from] = struct{}{}
@@ -824,54 +825,23 @@ func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) []error {
 		}
 		pool.promoteExecutables(addrs)
 	}
-	return txErr
-}
-
-// TxStatusData is returned by AddOrGetTxStatus for each transaction
-type TxStatusData struct {
-	Status uint
-	Data   []byte
+	return errs
 }
 
-const (
-	TxStatusUnknown = iota
-	TxStatusQueued
-	TxStatusPending
-	TxStatusIncluded // Data contains a TxChainPos struct
-	TxStatusError    // Data contains the error string
-)
-
-// AddOrGetTxStatus returns the status (unknown/pending/queued) of a batch of transactions
-// identified by their hashes in txHashes. Optionally the transactions themselves can be
-// passed too in txs, in which case the function will try adding the previously unknown ones
-// to the pool. If a new transaction cannot be added, TxStatusError is returned. Adding already
-// known transactions will return their previous status.
-// If txs is specified, txHashes is still required and has to match the transactions in txs.
-
-// Note: TxStatusIncluded is never returned by this function since the pool does not track
-// mined transactions. Included status can be checked by the caller (as it happens in the
-// LES protocol manager)
-func (pool *TxPool) AddOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []TxStatusData {
-	status := make([]TxStatusData, len(txHashes))
-	if txs != nil {
-		if len(txs) != len(txHashes) {
-			panic(nil)
-		}
-		txErr := pool.addTxs(txs, false)
-		for i, err := range txErr {
-			if err != nil {
-				status[i] = TxStatusData{TxStatusError, ([]byte)(err.Error())}
-			}
-		}
-	}
+// Status returns the status (unknown/pending/queued) of a batch of transactions
+// identified by their hashes.
+func (pool *TxPool) Status(hashes []common.Hash) []TxStatus {
+	pool.mu.RLock()
+	defer pool.mu.RUnlock()
 
-	for i, hash := range txHashes {
-		r, ok := pool.all[hash]
-		if ok {
-			if r.pending {
-				status[i] = TxStatusData{TxStatusPending, nil}
+	status := make([]TxStatus, len(hashes))
+	for i, hash := range hashes {
+		if tx := pool.all[hash]; tx != nil {
+			from, _ := types.Sender(pool.signer, tx) // already validated
+			if pool.pending[from].txs.items[tx.Nonce()] != nil {
+				status[i] = TxStatusPending
 			} else {
-				status[i] = TxStatusData{TxStatusQueued, nil}
+				status[i] = TxStatusQueued
 			}
 		}
 	}
@@ -884,18 +854,17 @@ func (pool *TxPool) Get(hash common.Hash) *types.Transaction {
 	pool.mu.RLock()
 	defer pool.mu.RUnlock()
 
-	return pool.all[hash].tx
+	return pool.all[hash]
 }
 
 // removeTx removes a single transaction from the queue, moving all subsequent
 // transactions back to the future queue.
 func (pool *TxPool) removeTx(hash common.Hash) {
 	// Fetch the transaction we wish to delete
-	txl, ok := pool.all[hash]
+	tx, ok := pool.all[hash]
 	if !ok {
 		return
 	}
-	tx := txl.tx
 	addr, _ := types.Sender(pool.signer, tx) // already validated during insertion
 
 	// Remove it from the list of known transactions

+ 2 - 3
eth/helper_test.go

@@ -97,7 +97,7 @@ type testTxPool struct {
 
 // AddRemotes appends a batch of transactions to the pool, and notifies any
 // listeners if the addition channel is non nil
-func (p *testTxPool) AddRemotes(txs []*types.Transaction) error {
+func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error {
 	p.lock.Lock()
 	defer p.lock.Unlock()
 
@@ -105,8 +105,7 @@ func (p *testTxPool) AddRemotes(txs []*types.Transaction) error {
 	if p.added != nil {
 		p.added <- txs
 	}
-
-	return nil
+	return make([]error, len(txs))
 }
 
 // Pending returns all the transactions known to the pool

+ 1 - 1
eth/protocol.go

@@ -97,7 +97,7 @@ var errorToString = map[int]string{
 
 type txPool interface {
 	// AddRemotes should add the given transactions to the pool.
-	AddRemotes([]*types.Transaction) error
+	AddRemotes([]*types.Transaction) []error
 
 	// Pending should return pending transactions.
 	// The slice should be modifiable by the caller.

+ 35 - 30
les/handler.go

@@ -89,7 +89,8 @@ type BlockChain interface {
 }
 
 type txPool interface {
-	AddOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []core.TxStatusData
+	AddRemotes(txs []*types.Transaction) []error
+	Status(hashes []common.Hash) []core.TxStatus
 }
 
 type ProtocolManager struct {
@@ -983,12 +984,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		if reject(uint64(reqCnt), MaxTxSend) {
 			return errResp(ErrRequestRejected, "")
 		}
-
-		txHashes := make([]common.Hash, len(txs))
-		for i, tx := range txs {
-			txHashes[i] = tx.Hash()
-		}
-		pm.addOrGetTxStatus(txs, txHashes)
+		pm.txpool.AddRemotes(txs)
 
 		_, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
 		pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
@@ -1010,16 +1006,25 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			return errResp(ErrRequestRejected, "")
 		}
 
-		txHashes := make([]common.Hash, len(req.Txs))
+		hashes := make([]common.Hash, len(req.Txs))
 		for i, tx := range req.Txs {
-			txHashes[i] = tx.Hash()
+			hashes[i] = tx.Hash()
+		}
+		stats := pm.txStatus(hashes)
+		for i, stat := range stats {
+			if stat.Status == core.TxStatusUnknown {
+				if errs := pm.txpool.AddRemotes([]*types.Transaction{req.Txs[i]}); errs[0] != nil {
+					stats[i].Error = errs[0]
+					continue
+				}
+				stats[i] = pm.txStatus([]common.Hash{hashes[i]})[0]
+			}
 		}
-
-		res := pm.addOrGetTxStatus(req.Txs, txHashes)
 
 		bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
 		pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
-		return p.SendTxStatus(req.ReqID, bv, res)
+
+		return p.SendTxStatus(req.ReqID, bv, stats)
 
 	case GetTxStatusMsg:
 		if pm.txpool == nil {
@@ -1027,22 +1032,20 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		}
 		// Transactions arrived, parse all of them and deliver to the pool
 		var req struct {
-			ReqID    uint64
-			TxHashes []common.Hash
+			ReqID  uint64
+			Hashes []common.Hash
 		}
 		if err := msg.Decode(&req); err != nil {
 			return errResp(ErrDecode, "msg %v: %v", msg, err)
 		}
-		reqCnt := len(req.TxHashes)
+		reqCnt := len(req.Hashes)
 		if reject(uint64(reqCnt), MaxTxStatus) {
 			return errResp(ErrRequestRejected, "")
 		}
-
-		res := pm.addOrGetTxStatus(nil, req.TxHashes)
-
 		bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
 		pm.server.fcCostStats.update(msg.Code, uint64(reqCnt), rcost)
-		return p.SendTxStatus(req.ReqID, bv, res)
+
+		return p.SendTxStatus(req.ReqID, bv, pm.txStatus(req.Hashes))
 
 	case TxStatusMsg:
 		if pm.odr == nil {
@@ -1052,7 +1055,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		p.Log().Trace("Received tx status response")
 		var resp struct {
 			ReqID, BV uint64
-			Status    []core.TxStatusData
+			Status    []core.TxStatus
 		}
 		if err := msg.Decode(&resp); err != nil {
 			return errResp(ErrDecode, "msg %v: %v", msg, err)
@@ -1103,19 +1106,21 @@ func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte {
 	return nil
 }
 
-func (pm *ProtocolManager) addOrGetTxStatus(txs []*types.Transaction, txHashes []common.Hash) []core.TxStatusData {
-	status := pm.txpool.AddOrGetTxStatus(txs, txHashes)
-	for i, _ := range status {
-		blockHash, blockNum, txIndex := core.GetTxLookupEntry(pm.chainDb, txHashes[i])
-		if blockHash != (common.Hash{}) {
-			enc, err := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: blockHash, BlockIndex: blockNum, Index: txIndex})
-			if err != nil {
-				panic(err)
+func (pm *ProtocolManager) txStatus(hashes []common.Hash) []txStatus {
+	stats := make([]txStatus, len(hashes))
+	for i, stat := range pm.txpool.Status(hashes) {
+		// Save the status we've got from the transaction pool
+		stats[i].Status = stat
+
+		// If the transaction is unknown to the pool, try looking it up locally
+		if stat == core.TxStatusUnknown {
+			if block, number, index := core.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) {
+				stats[i].Status = core.TxStatusIncluded
+				stats[i].Lookup = &core.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index}
 			}
-			status[i] = core.TxStatusData{Status: core.TxStatusIncluded, Data: enc}
 		}
 	}
-	return status
+	return stats
 }
 
 // NodeInfo retrieves some protocol metadata about the running host node.

+ 29 - 18
les/handler_test.go

@@ -20,8 +20,8 @@ import (
 	"bytes"
 	"math/big"
 	"math/rand"
-	"runtime"
 	"testing"
+	"time"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core"
@@ -423,7 +423,7 @@ func TestTransactionStatusLes2(t *testing.T) {
 
 	var reqID uint64
 
-	test := func(tx *types.Transaction, send bool, expStatus core.TxStatusData) {
+	test := func(tx *types.Transaction, send bool, expStatus txStatus) {
 		reqID++
 		if send {
 			cost := peer.GetRequestCost(SendTxV2Msg, 1)
@@ -432,7 +432,7 @@ func TestTransactionStatusLes2(t *testing.T) {
 			cost := peer.GetRequestCost(GetTxStatusMsg, 1)
 			sendRequest(peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()})
 		}
-		if err := expectResponse(peer.app, TxStatusMsg, reqID, testBufLimit, []core.TxStatusData{expStatus}); err != nil {
+		if err := expectResponse(peer.app, TxStatusMsg, reqID, testBufLimit, []txStatus{expStatus}); err != nil {
 			t.Errorf("transaction status mismatch")
 		}
 	}
@@ -441,20 +441,20 @@ func TestTransactionStatusLes2(t *testing.T) {
 
 	// test error status by sending an underpriced transaction
 	tx0, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, nil, nil), signer, testBankKey)
-	test(tx0, true, core.TxStatusData{Status: core.TxStatusError, Data: []byte("transaction underpriced")})
+	test(tx0, true, txStatus{Status: core.TxStatusUnknown, Error: core.ErrUnderpriced})
 
 	tx1, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey)
-	test(tx1, false, core.TxStatusData{Status: core.TxStatusUnknown}) // query before sending, should be unknown
-	test(tx1, true, core.TxStatusData{Status: core.TxStatusPending})  // send valid processable tx, should return pending
-	test(tx1, true, core.TxStatusData{Status: core.TxStatusPending})  // adding it again should not return an error
+	test(tx1, false, txStatus{Status: core.TxStatusUnknown}) // query before sending, should be unknown
+	test(tx1, true, txStatus{Status: core.TxStatusPending})  // send valid processable tx, should return pending
+	test(tx1, true, txStatus{Status: core.TxStatusPending})  // adding it again should not return an error
 
 	tx2, _ := types.SignTx(types.NewTransaction(1, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey)
 	tx3, _ := types.SignTx(types.NewTransaction(2, acc1Addr, big.NewInt(10000), bigTxGas, big.NewInt(100000000000), nil), signer, testBankKey)
 	// send transactions in the wrong order, tx3 should be queued
-	test(tx3, true, core.TxStatusData{Status: core.TxStatusQueued})
-	test(tx2, true, core.TxStatusData{Status: core.TxStatusPending})
+	test(tx3, true, txStatus{Status: core.TxStatusQueued})
+	test(tx2, true, txStatus{Status: core.TxStatusPending})
 	// query again, now tx3 should be pending too
-	test(tx3, false, core.TxStatusData{Status: core.TxStatusPending})
+	test(tx3, false, txStatus{Status: core.TxStatusPending})
 
 	// generate and add a block with tx1 and tx2 included
 	gchain, _ := core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 1, func(i int, block *core.BlockGen) {
@@ -464,13 +464,21 @@ func TestTransactionStatusLes2(t *testing.T) {
 	if _, err := chain.InsertChain(gchain); err != nil {
 		panic(err)
 	}
+	// wait until TxPool processes the inserted block
+	for i := 0; i < 10; i++ {
+		if pending, _ := txpool.Stats(); pending == 1 {
+			break
+		}
+		time.Sleep(100 * time.Millisecond)
+	}
+	if pending, _ := txpool.Stats(); pending != 1 {
+		t.Fatalf("pending count mismatch: have %d, want 1", pending)
+	}
 
 	// check if their status is included now
 	block1hash := core.GetCanonicalHash(db, 1)
-	tx1pos, _ := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0})
-	tx2pos, _ := rlp.EncodeToBytes(core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1})
-	test(tx1, false, core.TxStatusData{Status: core.TxStatusIncluded, Data: tx1pos})
-	test(tx2, false, core.TxStatusData{Status: core.TxStatusIncluded, Data: tx2pos})
+	test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}})
+	test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}})
 
 	// create a reorg that rolls them back
 	gchain, _ = core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), db, 2, func(i int, block *core.BlockGen) {})
@@ -478,13 +486,16 @@ func TestTransactionStatusLes2(t *testing.T) {
 		panic(err)
 	}
 	// wait until TxPool processes the reorg
-	for {
+	for i := 0; i < 10; i++ {
 		if pending, _ := txpool.Stats(); pending == 3 {
 			break
 		}
-		runtime.Gosched()
+		time.Sleep(100 * time.Millisecond)
+	}
+	if pending, _ := txpool.Stats(); pending != 3 {
+		t.Fatalf("pending count mismatch: have %d, want 3", pending)
 	}
 	// check if their status is pending again
-	test(tx1, false, core.TxStatusData{Status: core.TxStatusPending})
-	test(tx2, false, core.TxStatusData{Status: core.TxStatusPending})
+	test(tx1, false, txStatus{Status: core.TxStatusPending})
+	test(tx2, false, txStatus{Status: core.TxStatusPending})
 }

+ 2 - 3
les/peer.go

@@ -27,7 +27,6 @@ import (
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
-	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/eth"
 	"github.com/ethereum/go-ethereum/les/flowcontrol"
@@ -233,8 +232,8 @@ func (p *peer) SendHelperTrieProofs(reqID, bv uint64, resp HelperTrieResps) erro
 }
 
 // SendTxStatus sends a batch of transaction status records, corresponding to the ones requested.
-func (p *peer) SendTxStatus(reqID, bv uint64, status []core.TxStatusData) error {
-	return sendResponse(p.rw, TxStatusMsg, reqID, bv, status)
+func (p *peer) SendTxStatus(reqID, bv uint64, stats []txStatus) error {
+	return sendResponse(p.rw, TxStatusMsg, reqID, bv, stats)
 }
 
 // RequestHeadersByHash fetches a batch of blocks' headers corresponding to the

+ 7 - 0
les/protocol.go

@@ -27,6 +27,7 @@ import (
 	"math/big"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/crypto/secp256k1"
 	"github.com/ethereum/go-ethereum/rlp"
@@ -219,3 +220,9 @@ type CodeData []struct {
 }
 
 type proofsData [][]rlp.RawValue
+
+type txStatus struct {
+	Status core.TxStatus
+	Lookup *core.TxLookupEntry
+	Error  error
+}