Browse Source

rpc: add SetHeader method to Client (#21392)

Resolves #20163

Co-authored-by: Felix Lange <fjl@twurst.com>
rene 5 years ago
parent
commit
290d6bd903
3 changed files with 79 additions and 10 deletions
  1. 14 1
      rpc/client.go
  2. 37 0
      rpc/client_test.go
  3. 28 9
      rpc/http.go

+ 14 - 1
rpc/client.go

@@ -85,7 +85,7 @@ type Client struct {
 
 	// writeConn is used for writing to the connection on the caller's goroutine. It should
 	// only be accessed outside of dispatch, with the write lock held. The write lock is
-	// taken by sending on requestOp and released by sending on sendDone.
+	// taken by sending on reqInit and released by sending on reqSent.
 	writeConn jsonWriter
 
 	// for dispatch
@@ -260,6 +260,19 @@ func (c *Client) Close() {
 	}
 }
 
+// SetHeader adds a custom HTTP header to the client's requests.
+// This method only works for clients using HTTP, it doesn't have
+// any effect for clients using another transport.
+func (c *Client) SetHeader(key, value string) {
+	if !c.isHTTP {
+		return
+	}
+	conn := c.writeConn.(*httpConn)
+	conn.mu.Lock()
+	conn.headers.Set(key, value)
+	conn.mu.Unlock()
+}
+
 // Call performs a JSON-RPC call with the given arguments and unmarshals into
 // result if no error occurred.
 //

+ 37 - 0
rpc/client_test.go

@@ -26,6 +26,7 @@ import (
 	"os"
 	"reflect"
 	"runtime"
+	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -429,6 +430,42 @@ func TestClientNotificationStorm(t *testing.T) {
 	doTest(23000, true)
 }
 
+func TestClientSetHeader(t *testing.T) {
+	var gotHeader bool
+	srv := newTestServer()
+	httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.Header.Get("test") == "ok" {
+			gotHeader = true
+		}
+		srv.ServeHTTP(w, r)
+	}))
+	defer httpsrv.Close()
+	defer srv.Stop()
+
+	client, err := Dial(httpsrv.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer client.Close()
+
+	client.SetHeader("test", "ok")
+	if _, err := client.SupportedModules(); err != nil {
+		t.Fatal(err)
+	}
+	if !gotHeader {
+		t.Fatal("client did not set custom header")
+	}
+
+	// Check that Content-Type can be replaced.
+	client.SetHeader("content-type", "application/x-garbage")
+	_, err = client.SupportedModules()
+	if err == nil {
+		t.Fatal("no error for invalid content-type header")
+	} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
+		t.Fatalf("error is not related to content-type: %q", err)
+	}
+}
+
 func TestClientHTTP(t *testing.T) {
 	server := newTestServer()
 	defer server.Stop()

+ 28 - 9
rpc/http.go

@@ -26,6 +26,7 @@ import (
 	"io/ioutil"
 	"mime"
 	"net/http"
+	"net/url"
 	"sync"
 	"time"
 )
@@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
 
 type httpConn struct {
 	client    *http.Client
-	req       *http.Request
+	url       string
 	closeOnce sync.Once
 	closeCh   chan interface{}
+	mu        sync.Mutex // protects headers
+	headers   http.Header
 }
 
 // httpConn is treated specially by Client.
@@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error {
 }
 
 func (hc *httpConn) remoteAddr() string {
-	return hc.req.URL.String()
+	return hc.url
 }
 
 func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
@@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
 // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
 // using the provided HTTP Client.
 func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
-	req, err := http.NewRequest(http.MethodPost, endpoint, nil)
+	// Sanity check URL so we don't end up with a client that will fail every request.
+	_, err := url.Parse(endpoint)
 	if err != nil {
 		return nil, err
 	}
-	req.Header.Set("Content-Type", contentType)
-	req.Header.Set("Accept", contentType)
 
 	initctx := context.Background()
+	headers := make(http.Header, 2)
+	headers.Set("accept", contentType)
+	headers.Set("content-type", contentType)
 	return newClient(initctx, func(context.Context) (ServerCodec, error) {
-		return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
+		hc := &httpConn{
+			client:  client,
+			headers: headers,
+			url:     endpoint,
+			closeCh: make(chan interface{}),
+		}
+		return hc, nil
 	})
 }
 
@@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
 		if respBody != nil {
 			buf := new(bytes.Buffer)
 			if _, err2 := buf.ReadFrom(respBody); err2 == nil {
-				return fmt.Errorf("%v %v", err, buf.String())
+				return fmt.Errorf("%v: %v", err, buf.String())
 			}
 		}
 		return err
@@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
 	if err != nil {
 		return nil, err
 	}
-	req := hc.req.WithContext(ctx)
-	req.Body = ioutil.NopCloser(bytes.NewReader(body))
+	req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
+	if err != nil {
+		return nil, err
+	}
 	req.ContentLength = int64(len(body))
 
+	// set headers
+	hc.mu.Lock()
+	req.Header = hc.headers.Clone()
+	hc.mu.Unlock()
+
+	// do request
 	resp, err := hc.client.Do(req)
 	if err != nil {
 		return nil, err