|
|
@@ -26,6 +26,7 @@ import (
|
|
|
"io/ioutil"
|
|
|
"mime"
|
|
|
"net/http"
|
|
|
+ "net/url"
|
|
|
"sync"
|
|
|
"time"
|
|
|
)
|
|
|
@@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
|
|
|
|
|
|
type httpConn struct {
|
|
|
client *http.Client
|
|
|
- req *http.Request
|
|
|
+ url string
|
|
|
closeOnce sync.Once
|
|
|
closeCh chan interface{}
|
|
|
+ mu sync.Mutex // protects headers
|
|
|
+ headers http.Header
|
|
|
}
|
|
|
|
|
|
// httpConn is treated specially by Client.
|
|
|
@@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error {
|
|
|
}
|
|
|
|
|
|
func (hc *httpConn) remoteAddr() string {
|
|
|
- return hc.req.URL.String()
|
|
|
+ return hc.url
|
|
|
}
|
|
|
|
|
|
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
|
|
|
@@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
|
|
|
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
|
|
|
// using the provided HTTP Client.
|
|
|
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
|
|
|
- req, err := http.NewRequest(http.MethodPost, endpoint, nil)
|
|
|
+ // Sanity check URL so we don't end up with a client that will fail every request.
|
|
|
+ _, err := url.Parse(endpoint)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- req.Header.Set("Content-Type", contentType)
|
|
|
- req.Header.Set("Accept", contentType)
|
|
|
|
|
|
initctx := context.Background()
|
|
|
+ headers := make(http.Header, 2)
|
|
|
+ headers.Set("accept", contentType)
|
|
|
+ headers.Set("content-type", contentType)
|
|
|
return newClient(initctx, func(context.Context) (ServerCodec, error) {
|
|
|
- return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
|
|
|
+ hc := &httpConn{
|
|
|
+ client: client,
|
|
|
+ headers: headers,
|
|
|
+ url: endpoint,
|
|
|
+ closeCh: make(chan interface{}),
|
|
|
+ }
|
|
|
+ return hc, nil
|
|
|
})
|
|
|
}
|
|
|
|
|
|
@@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
|
|
|
if respBody != nil {
|
|
|
buf := new(bytes.Buffer)
|
|
|
if _, err2 := buf.ReadFrom(respBody); err2 == nil {
|
|
|
- return fmt.Errorf("%v %v", err, buf.String())
|
|
|
+ return fmt.Errorf("%v: %v", err, buf.String())
|
|
|
}
|
|
|
}
|
|
|
return err
|
|
|
@@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- req := hc.req.WithContext(ctx)
|
|
|
- req.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+ req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
req.ContentLength = int64(len(body))
|
|
|
|
|
|
+ // set headers
|
|
|
+ hc.mu.Lock()
|
|
|
+ req.Header = hc.headers.Clone()
|
|
|
+ hc.mu.Unlock()
|
|
|
+
|
|
|
+ // do request
|
|
|
resp, err := hc.client.Do(req)
|
|
|
if err != nil {
|
|
|
return nil, err
|