Browse Source

prevent discarding requests when parsing fails

Bas van Kervel 10 years ago
parent
commit
6be527dd52
2 changed files with 273 additions and 26 deletions
  1. 96 26
      rpc/codec/json.go
  2. 177 0
      rpc/codec/json_test.go

+ 96 - 26
rpc/codec/json.go

@@ -10,64 +10,134 @@ import (
 )
 
 const (
-	READ_TIMEOUT      = 15 // read timeout in seconds
+	READ_TIMEOUT      = 60 // in seconds
 	MAX_REQUEST_SIZE  = 1024 * 1024
 	MAX_RESPONSE_SIZE = 1024 * 1024
 )
 
+var (
+	// No new requests in buffer
+	EmptyRequestQueueError = fmt.Errorf("No incoming requests")
+	// Next request in buffer isn't yet complete
+	IncompleteRequestError = fmt.Errorf("Request incomplete")
+)
+
 // Json serialization support
 type JsonCodec struct {
-	c net.Conn
-	d *json.Decoder
+	c                net.Conn
+	reqBuffer        []byte
+	bytesInReqBuffer int
+	reqLastPos       int
 }
 
 // Create new JSON coder instance
 func NewJsonCoder(conn net.Conn) ApiCoder {
 	return &JsonCodec{
-		c: conn,
-		d: json.NewDecoder(conn),
+		c:                conn,
+		reqBuffer:        make([]byte, MAX_REQUEST_SIZE),
+		bytesInReqBuffer: 0,
+		reqLastPos:       0,
+	}
+}
+
+// Indication if the next request in the buffer is a batch request
+func (self *JsonCodec) isNextBatchReq() (bool, error) {
+	for i := 0; i < self.bytesInReqBuffer; i++ {
+		switch self.reqBuffer[i] {
+		case 0x20, 0x09, 0x0a, 0x0d: // allow leading whitespace (JSON whitespace RFC4627)
+			continue
+		case 0x7b: // single req
+			return false, nil
+		case 0x5b: // batch req
+			return true, nil
+		default:
+			return false, &json.InvalidUnmarshalError{}
+		}
+	}
+
+	return false, EmptyRequestQueueError
+}
+
+// remove parsed request from buffer
+func (self *JsonCodec) resetReqbuffer(pos int) {
+	copy(self.reqBuffer, self.reqBuffer[pos:self.bytesInReqBuffer])
+	self.reqLastPos = 0
+	self.bytesInReqBuffer -= pos
+}
+
+// parse request in buffer
+func (self *JsonCodec) nextRequest() (requests []*shared.Request, isBatch bool, err error) {
+	if isBatch, err := self.isNextBatchReq(); err == nil {
+		if isBatch {
+			requests = make([]*shared.Request, 0)
+			for ; self.reqLastPos <= self.bytesInReqBuffer; self.reqLastPos++ {
+				if err = json.Unmarshal(self.reqBuffer[:self.reqLastPos], &requests); err == nil {
+					self.resetReqbuffer(self.reqLastPos)
+					return requests, true, nil
+				}
+			}
+			return nil, true, IncompleteRequestError
+		} else {
+			request := shared.Request{}
+			for ; self.reqLastPos <= self.bytesInReqBuffer; self.reqLastPos++ {
+				if err = json.Unmarshal(self.reqBuffer[:self.reqLastPos], &request); err == nil {
+					requests := make([]*shared.Request, 1)
+					requests[0] = &request
+					self.resetReqbuffer(self.reqLastPos)
+					return requests, false, nil
+				}
+			}
+			return nil, true, IncompleteRequestError
+		}
+	} else {
+		return nil, false, err
 	}
 }
 
 // Serialize obj to JSON and write it to conn
 func (self *JsonCodec) ReadRequest() (requests []*shared.Request, isBatch bool, err error) {
+	if self.bytesInReqBuffer != 0 {
+		req, batch, err := self.nextRequest()
+		if err == nil {
+			return req, batch, err
+		}
+
+		if err != IncompleteRequestError {
+			return nil, false, err
+		}
+	}
 
+	// no/incomplete request in buffer -> read more data first
 	deadline := time.Now().Add(READ_TIMEOUT * time.Second)
 	if err := self.c.SetDeadline(deadline); err != nil {
 		return nil, false, err
 	}
 
+	var retErr error
 	for {
-		var err error
-		singleRequest := shared.Request{}
-		if err = self.d.Decode(&singleRequest); err == nil {
-			requests := make([]*shared.Request, 1)
-			requests[0] = &singleRequest
-			return requests, false, nil
+		n, err := self.c.Read(self.reqBuffer[self.bytesInReqBuffer:])
+		if err != nil {
+			retErr = err
+			break
 		}
 
-		fmt.Printf("err %T %v\n", err)
+		self.bytesInReqBuffer += n
 
-		if opErr, ok := err.(*net.OpError); ok {
-			if opErr.Timeout() {
-				break
-			}
+		requests, isBatch, err := self.nextRequest()
+		if err == nil {
+			return requests, isBatch, nil
 		}
 
-		requests = make([]*shared.Request, 0)
-		if err = self.d.Decode(&requests); err == nil {
-			return requests, true, nil
+		if err == IncompleteRequestError || err == EmptyRequestQueueError {
+			continue // need more data
 		}
 
-		if opErr, ok := err.(*net.OpError); ok {
-			if opErr.Timeout() {
-				break
-			}
-		}
+		retErr = err
+		break
 	}
 
-	self.c.Close() // timeout
-	return nil, false, fmt.Errorf("Timeout reading request")
+	self.c.Close()
+	return nil, false, retErr
 }
 
 func (self *JsonCodec) ReadResponse() (interface{}, error) {

+ 177 - 0
rpc/codec/json_test.go

@@ -0,0 +1,177 @@
+package codec
+
+import (
+	"bytes"
+	"io"
+	"net"
+	"testing"
+	"time"
+)
+
+type jsonTestConn struct {
+	buffer *bytes.Buffer
+}
+
+func newJsonTestConn(data []byte) *jsonTestConn {
+	return &jsonTestConn{
+		buffer: bytes.NewBuffer(data),
+	}
+}
+
+func (self *jsonTestConn) Read(p []byte) (n int, err error) {
+	return self.buffer.Read(p)
+}
+
+func (self *jsonTestConn) Write(p []byte) (n int, err error) {
+	return self.buffer.Write(p)
+}
+
+func (self *jsonTestConn) Close() error {
+	// not implemented
+	return nil
+}
+
+func (self *jsonTestConn) LocalAddr() net.Addr {
+	// not implemented
+	return nil
+}
+
+func (self *jsonTestConn) RemoteAddr() net.Addr {
+	// not implemented
+	return nil
+}
+
+func (self *jsonTestConn) SetDeadline(t time.Time) error {
+	return nil
+}
+
+func (self *jsonTestConn) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+func (self *jsonTestConn) SetWriteDeadline(t time.Time) error {
+	return nil
+}
+
+func TestJsonDecoderWithValidRequest(t *testing.T) {
+	reqdata := []byte(`{"jsonrpc":"2.0","method":"modules","params":[],"id":64}`)
+	decoder := newJsonTestConn(reqdata)
+
+	jsonDecoder := NewJsonCoder(decoder)
+	requests, batch, err := jsonDecoder.ReadRequest()
+
+	if err != nil {
+		t.Errorf("Read valid request failed - %v", err)
+	}
+
+	if len(requests) != 1 {
+		t.Errorf("Expected to get a single request but got %d", len(requests))
+	}
+
+	if batch {
+		t.Errorf("Got batch indication while expecting single request")
+	}
+
+	if requests[0].Id != float64(64) {
+		t.Errorf("Expected req.Id == 64 but got %v", requests[0].Id)
+	}
+
+	if requests[0].Method != "modules" {
+		t.Errorf("Expected req.Method == 'modules' got '%s'", requests[0].Method)
+	}
+}
+
+func TestJsonDecoderWithValidBatchRequest(t *testing.T) {
+	reqdata := []byte(`[{"jsonrpc":"2.0","method":"modules","params":[],"id":64},
+		{"jsonrpc":"2.0","method":"modules","params":[],"id":64}]`)
+	decoder := newJsonTestConn(reqdata)
+
+	jsonDecoder := NewJsonCoder(decoder)
+	requests, batch, err := jsonDecoder.ReadRequest()
+
+	if err != nil {
+		t.Errorf("Read valid batch request failed - %v", err)
+	}
+
+	if len(requests) != 2 {
+		t.Errorf("Expected to get two requests but got %d", len(requests))
+	}
+
+	if !batch {
+		t.Errorf("Got no batch indication while expecting batch request")
+	}
+
+	for i := 0; i < len(requests); i++ {
+		if requests[i].Id != float64(64) {
+			t.Errorf("Expected req.Id == 64 but got %v", requests[i].Id)
+		}
+
+		if requests[i].Method != "modules" {
+			t.Errorf("Expected req.Method == 'modules' got '%s'", requests[i].Method)
+		}
+	}
+}
+
+func TestJsonDecoderWithIncompleteMessage(t *testing.T) {
+	reqdata := []byte(`{"jsonrpc":"2.0","method":"modules","pa`)
+	decoder := newJsonTestConn(reqdata)
+
+	jsonDecoder := NewJsonCoder(decoder)
+	requests, batch, err := jsonDecoder.ReadRequest()
+
+	if err != io.EOF {
+		t.Errorf("Expected to read an incomplete request err but got %v", err)
+	}
+
+	// remaining message
+	decoder.Write([]byte(`rams":[],"id":64}`))
+	requests, batch, err = jsonDecoder.ReadRequest()
+
+	if err != nil {
+		t.Errorf("Read valid request failed - %v", err)
+	}
+
+	if len(requests) != 1 {
+		t.Errorf("Expected to get a single request but got %d", len(requests))
+	}
+
+	if batch {
+		t.Errorf("Got batch indication while expecting single request")
+	}
+
+	if requests[0].Id != float64(64) {
+		t.Errorf("Expected req.Id == 64 but got %v", requests[0].Id)
+	}
+
+	if requests[0].Method != "modules" {
+		t.Errorf("Expected req.Method == 'modules' got '%s'", requests[0].Method)
+	}
+}
+
+func TestJsonDecoderWithInvalidIncompleteMessage(t *testing.T) {
+	reqdata := []byte(`{"jsonrpc":"2.0","method":"modules","pa`)
+	decoder := newJsonTestConn(reqdata)
+
+	jsonDecoder := NewJsonCoder(decoder)
+	requests, batch, err := jsonDecoder.ReadRequest()
+
+	if err != io.EOF {
+		t.Errorf("Expected to read an incomplete request err but got %v", err)
+	}
+
+	// remaining message
+	decoder.Write([]byte(`rams":[],"id:64"}`))
+	requests, batch, err = jsonDecoder.ReadRequest()
+
+	if err == nil {
+		t.Errorf("Expected an error but got nil")
+	}
+
+	if len(requests) != 0 {
+		t.Errorf("Expected to get no requests but got %d", len(requests))
+	}
+
+	if batch {
+		t.Errorf("Got batch indication while expecting non batch")
+	}
+}