瀏覽代碼

rlp: fix encReader returning nil buffers to the pool

The bug can cause crashes if Read is called after EOF has been returned.
No code performs such calls right now, but hitting the bug gets more
likely as rlp.EncodeToReader gets used in more places.
Felix Lange 10 年之前
父節點
當前提交
ac32f52ca6
共有 2 個文件被更改,包括 32 次插入4 次删除
  1. 9 4
      rlp/encode.go
  2. 23 0
      rlp/encode_test.go

+ 9 - 4
rlp/encode.go

@@ -90,8 +90,8 @@ func Encode(w io.Writer, val interface{}) error {
 		return outer.encode(val)
 	}
 	eb := encbufPool.Get().(*encbuf)
-	eb.reset()
 	defer encbufPool.Put(eb)
+	eb.reset()
 	if err := eb.encode(val); err != nil {
 		return err
 	}
@@ -102,8 +102,8 @@ func Encode(w io.Writer, val interface{}) error {
 // Please see the documentation of Encode for the encoding rules.
 func EncodeToBytes(val interface{}) ([]byte, error) {
 	eb := encbufPool.Get().(*encbuf)
-	eb.reset()
 	defer encbufPool.Put(eb)
+	eb.reset()
 	if err := eb.encode(val); err != nil {
 		return nil, err
 	}
@@ -288,8 +288,13 @@ type encReader struct {
 func (r *encReader) Read(b []byte) (n int, err error) {
 	for {
 		if r.piece = r.next(); r.piece == nil {
-			encbufPool.Put(r.buf)
-			r.buf = nil
+			// Put the encode buffer back into the pool at EOF when it
+			// is first encountered. Subsequent calls still return EOF
+			// as the error but the buffer is no longer valid.
+			if r.buf != nil {
+				encbufPool.Put(r.buf)
+				r.buf = nil
+			}
 			return n, io.EOF
 		}
 		nn := copy(b[n:], r.piece)

+ 23 - 0
rlp/encode_test.go

@@ -23,6 +23,7 @@ import (
 	"io"
 	"io/ioutil"
 	"math/big"
+	"sync"
 	"testing"
 )
 
@@ -306,3 +307,25 @@ func TestEncodeToReaderPiecewise(t *testing.T) {
 		return output, nil
 	})
 }
+
+// This is a regression test verifying that encReader
+// returns its encbuf to the pool only once.
+func TestEncodeToReaderReturnToPool(t *testing.T) {
+	buf := make([]byte, 50)
+	wg := new(sync.WaitGroup)
+	for i := 0; i < 5; i++ {
+		wg.Add(1)
+		go func() {
+			for i := 0; i < 1000; i++ {
+				_, r, _ := EncodeToReader("foo")
+				ioutil.ReadAll(r)
+				r.Read(buf)
+				r.Read(buf)
+				r.Read(buf)
+				r.Read(buf)
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+}