websocket_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. // Copyright 2018 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package rpc
  17. import (
  18. "context"
  19. "errors"
  20. "io"
  21. "net"
  22. "net/http"
  23. "net/http/httptest"
  24. "net/http/httputil"
  25. "net/url"
  26. "strings"
  27. "sync/atomic"
  28. "testing"
  29. "time"
  30. "github.com/gorilla/websocket"
  31. )
  32. func TestWebsocketClientHeaders(t *testing.T) {
  33. t.Parallel()
  34. endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com")
  35. if err != nil {
  36. t.Fatalf("wsGetConfig failed: %s", err)
  37. }
  38. if endpoint != "wss://example.com:1234" {
  39. t.Fatal("User should have been stripped from the URL")
  40. }
  41. if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" {
  42. t.Fatal("Basic auth header is incorrect")
  43. }
  44. if header.Get("origin") != "https://example.com" {
  45. t.Fatal("Origin not set")
  46. }
  47. }
  48. // This test checks that the server rejects connections from disallowed origins.
  49. func TestWebsocketOriginCheck(t *testing.T) {
  50. t.Parallel()
  51. var (
  52. srv = newTestServer()
  53. httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}))
  54. wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
  55. )
  56. defer srv.Stop()
  57. defer httpsrv.Close()
  58. client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com")
  59. if err == nil {
  60. client.Close()
  61. t.Fatal("no error for wrong origin")
  62. }
  63. wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"}
  64. if !errors.Is(err, wantErr) {
  65. t.Fatalf("wrong error for wrong origin: %q", err)
  66. }
  67. // Connections without origin header should work.
  68. client, err = DialWebsocket(context.Background(), wsURL, "")
  69. if err != nil {
  70. t.Fatalf("error for empty origin: %v", err)
  71. }
  72. client.Close()
  73. }
  74. // This test checks whether calls exceeding the request size limit are rejected.
  75. func TestWebsocketLargeCall(t *testing.T) {
  76. t.Parallel()
  77. var (
  78. srv = newTestServer()
  79. httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
  80. wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
  81. )
  82. defer srv.Stop()
  83. defer httpsrv.Close()
  84. client, err := DialWebsocket(context.Background(), wsURL, "")
  85. if err != nil {
  86. t.Fatalf("can't dial: %v", err)
  87. }
  88. defer client.Close()
  89. // This call sends slightly less than the limit and should work.
  90. var result echoResult
  91. arg := strings.Repeat("x", maxRequestContentLength-200)
  92. if err := client.Call(&result, "test_echo", arg, 1); err != nil {
  93. t.Fatalf("valid call didn't work: %v", err)
  94. }
  95. if result.String != arg {
  96. t.Fatal("wrong string echoed")
  97. }
  98. // This call sends twice the allowed size and shouldn't work.
  99. arg = strings.Repeat("x", maxRequestContentLength*2)
  100. err = client.Call(&result, "test_echo", arg)
  101. if err == nil {
  102. t.Fatal("no error for too large call")
  103. }
  104. }
  105. func TestWebsocketPeerInfo(t *testing.T) {
  106. var (
  107. s = newTestServer()
  108. ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"}))
  109. tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:")
  110. )
  111. defer s.Stop()
  112. defer ts.Close()
  113. ctx := context.Background()
  114. c, err := DialWebsocket(ctx, tsurl, "origin.example.com")
  115. if err != nil {
  116. t.Fatal(err)
  117. }
  118. // Request peer information.
  119. var connInfo PeerInfo
  120. if err := c.Call(&connInfo, "test_peerInfo"); err != nil {
  121. t.Fatal(err)
  122. }
  123. if connInfo.RemoteAddr == "" {
  124. t.Error("RemoteAddr not set")
  125. }
  126. if connInfo.Transport != "ws" {
  127. t.Errorf("wrong Transport %q", connInfo.Transport)
  128. }
  129. if connInfo.HTTP.UserAgent != "Go-http-client/1.1" {
  130. t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent)
  131. }
  132. if connInfo.HTTP.Origin != "origin.example.com" {
  133. t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent)
  134. }
  135. }
  136. // This test checks that client handles WebSocket ping frames correctly.
  137. func TestClientWebsocketPing(t *testing.T) {
  138. t.Parallel()
  139. var (
  140. sendPing = make(chan struct{})
  141. server = wsPingTestServer(t, sendPing)
  142. ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
  143. )
  144. defer cancel()
  145. defer server.Shutdown(ctx)
  146. client, err := DialContext(ctx, "ws://"+server.Addr)
  147. if err != nil {
  148. t.Fatalf("client dial error: %v", err)
  149. }
  150. defer client.Close()
  151. resultChan := make(chan int)
  152. sub, err := client.EthSubscribe(ctx, resultChan, "foo")
  153. if err != nil {
  154. t.Fatalf("client subscribe error: %v", err)
  155. }
  156. // Note: Unsubscribe is not called on this subscription because the mockup
  157. // server can't handle the request.
  158. // Wait for the context's deadline to be reached before proceeding.
  159. // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798
  160. <-ctx.Done()
  161. close(sendPing)
  162. // Wait for the subscription result.
  163. timeout := time.NewTimer(5 * time.Second)
  164. defer timeout.Stop()
  165. for {
  166. select {
  167. case err := <-sub.Err():
  168. t.Error("client subscription error:", err)
  169. case result := <-resultChan:
  170. t.Log("client got result:", result)
  171. return
  172. case <-timeout.C:
  173. t.Error("didn't get any result within the test timeout")
  174. return
  175. }
  176. }
  177. }
  178. // This checks that the websocket transport can deal with large messages.
  179. func TestClientWebsocketLargeMessage(t *testing.T) {
  180. var (
  181. srv = NewServer()
  182. httpsrv = httptest.NewServer(srv.WebsocketHandler(nil))
  183. wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
  184. )
  185. defer srv.Stop()
  186. defer httpsrv.Close()
  187. respLength := wsMessageSizeLimit - 50
  188. srv.RegisterName("test", largeRespService{respLength})
  189. c, err := DialWebsocket(context.Background(), wsURL, "")
  190. if err != nil {
  191. t.Fatal(err)
  192. }
  193. var r string
  194. if err := c.Call(&r, "test_largeResp"); err != nil {
  195. t.Fatal("call failed:", err)
  196. }
  197. if len(r) != respLength {
  198. t.Fatalf("response has wrong length %d, want %d", len(r), respLength)
  199. }
  200. }
  201. func TestClientWebsocketSevered(t *testing.T) {
  202. t.Parallel()
  203. var (
  204. server = wsPingTestServer(t, nil)
  205. ctx = context.Background()
  206. )
  207. defer server.Shutdown(ctx)
  208. u, err := url.Parse("http://" + server.Addr)
  209. if err != nil {
  210. t.Fatal(err)
  211. }
  212. rproxy := httputil.NewSingleHostReverseProxy(u)
  213. var severable *severableReadWriteCloser
  214. rproxy.ModifyResponse = func(response *http.Response) error {
  215. severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)}
  216. response.Body = severable
  217. return nil
  218. }
  219. frontendProxy := httptest.NewServer(rproxy)
  220. defer frontendProxy.Close()
  221. wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:")
  222. client, err := DialWebsocket(ctx, wsURL, "")
  223. if err != nil {
  224. t.Fatalf("client dial error: %v", err)
  225. }
  226. defer client.Close()
  227. resultChan := make(chan int)
  228. sub, err := client.EthSubscribe(ctx, resultChan, "foo")
  229. if err != nil {
  230. t.Fatalf("client subscribe error: %v", err)
  231. }
  232. // sever the connection
  233. severable.Sever()
  234. // Wait for subscription error.
  235. timeout := time.NewTimer(3 * wsPingInterval)
  236. defer timeout.Stop()
  237. for {
  238. select {
  239. case err := <-sub.Err():
  240. t.Log("client subscription error:", err)
  241. return
  242. case result := <-resultChan:
  243. t.Error("unexpected result:", result)
  244. return
  245. case <-timeout.C:
  246. t.Error("didn't get any error within the test timeout")
  247. return
  248. }
  249. }
  250. }
  251. // wsPingTestServer runs a WebSocket server which accepts a single subscription request.
  252. // When a value arrives on sendPing, the server sends a ping frame, waits for a matching
  253. // pong and finally delivers a single subscription result.
  254. func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server {
  255. var srv http.Server
  256. shutdown := make(chan struct{})
  257. srv.RegisterOnShutdown(func() {
  258. close(shutdown)
  259. })
  260. srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  261. // Upgrade to WebSocket.
  262. upgrader := websocket.Upgrader{
  263. CheckOrigin: func(r *http.Request) bool { return true },
  264. }
  265. conn, err := upgrader.Upgrade(w, r, nil)
  266. if err != nil {
  267. t.Errorf("server WS upgrade error: %v", err)
  268. return
  269. }
  270. defer conn.Close()
  271. // Handle the connection.
  272. wsPingTestHandler(t, conn, shutdown, sendPing)
  273. })
  274. // Start the server.
  275. listener, err := net.Listen("tcp", "127.0.0.1:0")
  276. if err != nil {
  277. t.Fatal("can't listen:", err)
  278. }
  279. srv.Addr = listener.Addr().String()
  280. go srv.Serve(listener)
  281. return &srv
  282. }
  283. func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) {
  284. // Canned responses for the eth_subscribe call in TestClientWebsocketPing.
  285. const (
  286. subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}`
  287. subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}`
  288. )
  289. // Handle subscribe request.
  290. if _, _, err := conn.ReadMessage(); err != nil {
  291. t.Errorf("server read error: %v", err)
  292. return
  293. }
  294. if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil {
  295. t.Errorf("server write error: %v", err)
  296. return
  297. }
  298. // Read from the connection to process control messages.
  299. var pongCh = make(chan string)
  300. conn.SetPongHandler(func(d string) error {
  301. t.Logf("server got pong: %q", d)
  302. pongCh <- d
  303. return nil
  304. })
  305. go func() {
  306. for {
  307. typ, msg, err := conn.ReadMessage()
  308. if err != nil {
  309. return
  310. }
  311. t.Logf("server got message (%d): %q", typ, msg)
  312. }
  313. }()
  314. // Write messages.
  315. var (
  316. wantPong string
  317. timer = time.NewTimer(0)
  318. )
  319. defer timer.Stop()
  320. <-timer.C
  321. for {
  322. select {
  323. case _, open := <-sendPing:
  324. if !open {
  325. sendPing = nil
  326. }
  327. t.Logf("server sending ping")
  328. conn.WriteMessage(websocket.PingMessage, []byte("ping"))
  329. wantPong = "ping"
  330. case data := <-pongCh:
  331. if wantPong == "" {
  332. t.Errorf("unexpected pong")
  333. } else if data != wantPong {
  334. t.Errorf("got pong with wrong data %q", data)
  335. }
  336. wantPong = ""
  337. timer.Reset(200 * time.Millisecond)
  338. case <-timer.C:
  339. t.Logf("server sending response")
  340. conn.WriteMessage(websocket.TextMessage, []byte(subNotify))
  341. case <-shutdown:
  342. conn.Close()
  343. return
  344. }
  345. }
  346. }
  // severableReadWriteCloser wraps an io.ReadWriteCloser. Until Sever is called,
// Read and Write are forwarded to the wrapped value. Afterwards, Read reports
// zero bytes with a nil error and Write reports success while discarding the
// data. Close is always forwarded.
type severableReadWriteCloser struct {
	io.ReadWriteCloser
	severed int32 // set to 1 by Sever; only accessed through sync/atomic
}

// Sever stops all further reads and writes from reaching the wrapped value.
func (s *severableReadWriteCloser) Sever() {
	atomic.StoreInt32(&s.severed, 1)
}

// isSevered reports whether Sever has been called.
func (s *severableReadWriteCloser) isSevered() bool {
	return atomic.LoadInt32(&s.severed) > 0
}

// Read forwards to the wrapped reader until severed, then returns (0, nil).
// NOTE: the io.Reader contract discourages (0, nil); callers that loop on Read
// will keep retrying rather than see an error.
func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) {
	if !s.isSevered() {
		return s.ReadWriteCloser.Read(p)
	}
	return 0, nil
}

// Write forwards to the wrapped writer until severed, then silently drops p
// while reporting it as fully written.
func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) {
	if !s.isSevered() {
		return s.ReadWriteCloser.Write(p)
	}
	return len(p), nil
}

// Close closes the wrapped value, regardless of whether it was severed.
func (s *severableReadWriteCloser) Close() error {
	return s.ReadWriteCloser.Close()
}