websocket.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. // Copyright 2015 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package rpc
  17. import (
  18. "context"
  19. "encoding/base64"
  20. "fmt"
  21. "net/http"
  22. "net/url"
  23. "os"
  24. "strings"
  25. "sync"
  26. "time"
  27. mapset "github.com/deckarep/golang-set"
  28. "github.com/ethereum/go-ethereum/log"
  29. "github.com/gorilla/websocket"
  30. )
  // Buffer sizes, keep-alive timing and the frame size limit shared by the
// websocket server handler and the client dialer in this file.
const (
	wsReadBuffer       = 1024             // read buffer size given to the upgrader/dialer
	wsWriteBuffer      = 1024             // write buffer size given to the upgrader/dialer
	wsPingInterval     = time.Minute      // idle time after which pingLoop sends a ping
	wsPingWriteTimeout = 5 * time.Second  // write deadline applied while sending a ping
	wsPongTimeout      = 30 * time.Second // read deadline armed after a ping; cleared by the pong handler
	wsMessageSizeLimit = 15 << 20         // per-message read limit (15 MiB)
)

// wsBufferPool is the WriteBufferPool shared by the server upgrader and the
// client dialer.
var wsBufferPool = &sync.Pool{}
  40. // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
  41. //
  42. // allowedOrigins should be a comma-separated list of allowed origin URLs.
  43. // To allow connections with any origin, pass "*".
  44. func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
  45. var upgrader = websocket.Upgrader{
  46. ReadBufferSize: wsReadBuffer,
  47. WriteBufferSize: wsWriteBuffer,
  48. WriteBufferPool: wsBufferPool,
  49. CheckOrigin: wsHandshakeValidator(allowedOrigins),
  50. }
  51. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  52. conn, err := upgrader.Upgrade(w, r, nil)
  53. if err != nil {
  54. log.Debug("WebSocket upgrade failed", "err", err)
  55. return
  56. }
  57. codec := newWebsocketCodec(conn, r.Host, r.Header)
  58. s.ServeCodec(codec, 0)
  59. })
  60. }
  61. // wsHandshakeValidator returns a handler that verifies the origin during the
  62. // websocket upgrade process. When a '*' is specified as an allowed origins all
  63. // connections are accepted.
  64. func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
  65. origins := mapset.NewSet()
  66. allowAllOrigins := false
  67. for _, origin := range allowedOrigins {
  68. if origin == "*" {
  69. allowAllOrigins = true
  70. }
  71. if origin != "" {
  72. origins.Add(origin)
  73. }
  74. }
  75. // allow localhost if no allowedOrigins are specified.
  76. if len(origins.ToSlice()) == 0 {
  77. origins.Add("http://localhost")
  78. if hostname, err := os.Hostname(); err == nil {
  79. origins.Add("http://" + hostname)
  80. }
  81. }
  82. log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
  83. f := func(req *http.Request) bool {
  84. // Skip origin verification if no Origin header is present. The origin check
  85. // is supposed to protect against browser based attacks. Browsers always set
  86. // Origin. Non-browser software can put anything in origin and checking it doesn't
  87. // provide additional security.
  88. if _, ok := req.Header["Origin"]; !ok {
  89. return true
  90. }
  91. // Verify origin against allow list.
  92. origin := strings.ToLower(req.Header.Get("Origin"))
  93. if allowAllOrigins || originIsAllowed(origins, origin) {
  94. return true
  95. }
  96. log.Warn("Rejected WebSocket connection", "origin", origin)
  97. return false
  98. }
  99. return f
  100. }
  // wsHandshakeError reports a failed websocket dial. It optionally carries the
// HTTP status line returned by the server.
type wsHandshakeError struct {
	err    error
	status string
}

// Error returns the underlying error text. When an HTTP status was recorded,
// the status is appended in parentheses.
func (e wsHandshakeError) Error() string {
	msg := e.err.Error()
	if e.status == "" {
		return msg
	}
	return msg + " (HTTP status " + e.status + ")"
}
  112. func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
  113. it := allowedOrigins.Iterator()
  114. for origin := range it.C {
  115. if ruleAllowsOrigin(origin.(string), browserOrigin) {
  116. return true
  117. }
  118. }
  119. return false
  120. }
  121. func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
  122. var (
  123. allowedScheme, allowedHostname, allowedPort string
  124. browserScheme, browserHostname, browserPort string
  125. err error
  126. )
  127. allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
  128. if err != nil {
  129. log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
  130. return false
  131. }
  132. browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
  133. if err != nil {
  134. log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
  135. return false
  136. }
  137. if allowedScheme != "" && allowedScheme != browserScheme {
  138. return false
  139. }
  140. if allowedHostname != "" && allowedHostname != browserHostname {
  141. return false
  142. }
  143. if allowedPort != "" && allowedPort != browserPort {
  144. return false
  145. }
  146. return true
  147. }
  // parseOriginURL splits an origin (or an allow-list entry) into its scheme,
// hostname and port, all lower-cased. Missing components are returned as "".
//
// Inputs containing "://" are treated as full URLs. Anything else is a partial
// spec. "host:port" is read by url.Parse as scheme "host" with opaque "port".
// A bare "host" has no ':' and lands in the path, so the input itself is used
// as the hostname.
func parseOriginURL(origin string) (string, string, string, error) {
	lower := strings.ToLower(origin)
	parsedURL, err := url.Parse(lower)
	if err != nil {
		return "", "", "", err
	}
	if strings.Contains(lower, "://") {
		return parsedURL.Scheme, parsedURL.Hostname(), parsedURL.Port(), nil
	}
	hostname, port := parsedURL.Scheme, parsedURL.Opaque
	if hostname == "" {
		// Use the lower-cased input, not the raw one. Browser origins are
		// lower-cased before matching, so a rule like "MyHost" must become
		// "myhost" or it could never match.
		hostname = lower
	}
	return "", hostname, port, nil
}
  168. // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
  169. // that is listening on the given endpoint using the provided dialer.
  170. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
  171. endpoint, header, err := wsClientHeaders(endpoint, origin)
  172. if err != nil {
  173. return nil, err
  174. }
  175. return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
  176. conn, resp, err := dialer.DialContext(ctx, endpoint, header)
  177. if err != nil {
  178. hErr := wsHandshakeError{err: err}
  179. if resp != nil {
  180. hErr.status = resp.Status
  181. }
  182. return nil, hErr
  183. }
  184. return newWebsocketCodec(conn, endpoint, header), nil
  185. })
  186. }
  187. // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
  188. // that is listening on the given endpoint.
  189. //
  190. // The context is used for the initial connection establishment. It does not
  191. // affect subsequent interactions with the client.
  192. func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
  193. dialer := websocket.Dialer{
  194. ReadBufferSize: wsReadBuffer,
  195. WriteBufferSize: wsWriteBuffer,
  196. WriteBufferPool: wsBufferPool,
  197. }
  198. return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
  199. }
  // wsClientHeaders prepares the handshake for dialing endpoint. It returns the
// endpoint with any userinfo stripped, plus the request headers to send:
//   - Origin, when origin is non-empty;
//   - a Basic Authorization header, when endpoint carries userinfo.
//
// If endpoint cannot be parsed, it is returned unchanged together with the
// parse error and a nil header.
func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
	endpointURL, err := url.Parse(endpoint)
	if err != nil {
		return endpoint, nil, err
	}
	header := make(http.Header)
	if origin != "" {
		header.Add("origin", origin)
	}
	if user := endpointURL.User; user != nil {
		// Build "username:password" from the decoded values, the same way
		// net/http does for URLs with userinfo. Userinfo.String() returns the
		// percent-escaped form (e.g. "%40" for '@'), which the server would
		// receive verbatim. It also drops the ':' that RFC 7617 requires when
		// no password is set.
		password, _ := user.Password()
		credentials := user.Username() + ":" + password
		header.Add("authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(credentials)))
		endpointURL.User = nil
	}
	return endpointURL.String(), header, nil
}
  216. type websocketCodec struct {
  217. *jsonCodec
  218. conn *websocket.Conn
  219. info PeerInfo
  220. wg sync.WaitGroup
  221. pingReset chan struct{}
  222. }
  223. func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
  224. conn.SetReadLimit(wsMessageSizeLimit)
  225. conn.SetPongHandler(func(appData string) error {
  226. conn.SetReadDeadline(time.Time{})
  227. return nil
  228. })
  229. wc := &websocketCodec{
  230. jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
  231. conn: conn,
  232. pingReset: make(chan struct{}, 1),
  233. info: PeerInfo{
  234. Transport: "ws",
  235. RemoteAddr: conn.RemoteAddr().String(),
  236. },
  237. }
  238. // Fill in connection details.
  239. wc.info.HTTP.Host = host
  240. wc.info.HTTP.Origin = req.Get("Origin")
  241. wc.info.HTTP.UserAgent = req.Get("User-Agent")
  242. // Start pinger.
  243. wc.wg.Add(1)
  244. go wc.pingLoop()
  245. return wc
  246. }
  247. func (wc *websocketCodec) close() {
  248. wc.jsonCodec.close()
  249. wc.wg.Wait()
  250. }
  251. func (wc *websocketCodec) peerInfo() PeerInfo {
  252. return wc.info
  253. }
  254. func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
  255. err := wc.jsonCodec.writeJSON(ctx, v)
  256. if err == nil {
  257. // Notify pingLoop to delay the next idle ping.
  258. select {
  259. case wc.pingReset <- struct{}{}:
  260. default:
  261. }
  262. }
  263. return err
  264. }
  265. // pingLoop sends periodic ping frames when the connection is idle.
  266. func (wc *websocketCodec) pingLoop() {
  267. var timer = time.NewTimer(wsPingInterval)
  268. defer wc.wg.Done()
  269. defer timer.Stop()
  270. for {
  271. select {
  272. case <-wc.closed():
  273. return
  274. case <-wc.pingReset:
  275. if !timer.Stop() {
  276. <-timer.C
  277. }
  278. timer.Reset(wsPingInterval)
  279. case <-timer.C:
  280. wc.jsonCodec.encMu.Lock()
  281. wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
  282. wc.conn.WriteMessage(websocket.PingMessage, nil)
  283. wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout))
  284. wc.jsonCodec.encMu.Unlock()
  285. timer.Reset(wsPingInterval)
  286. }
  287. }
  288. }