web3_providers.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import json
  2. import threading
  3. from websocket import create_connection, WebSocketConnectionClosedException
  4. from web3.providers.base import JSONBaseProvider
  5. from web3.exceptions import ProviderConnectionError
  # Module-level thread-local storage: any attribute set on this object is
  # visible only to the thread that set it, so a WebSocket connection stored
  # here is private to that thread.
  thread_local_storage: threading.local = threading.local()
  class SyncWebSocketProvider(JSONBaseProvider):
      """
      A custom synchronous WebSocket provider for web3.py v6+.

      It uses the `websocket-client` library to provide a synchronous interface.
      Each thread keeps its own connection per provider instance. The connection
      is created lazily on first use. It is closed and discarded whenever it fails,
      so the next request starts from a fresh socket.
      """

      def __init__(self, endpoint_uri: str, timeout: float = 10):
          """
          :param endpoint_uri: WebSocket URL passed to ``create_connection``.
          :param timeout: forwarded as ``timeout=`` to ``create_connection``
              (websocket-client uses it as the socket timeout, in seconds).
          """
          self.endpoint_uri = endpoint_uri
          self.timeout = timeout
          # Storage is per instance, not module-level. A shared one would let two
          # providers with different endpoints reuse one socket within a thread.
          self._local = threading.local()
          super().__init__()

      def get_local_connection(self):
          """
          Return the calling thread's connection for this provider, opening it if needed.

          :raises ProviderConnectionError: if the connection cannot be opened.
          """
          conn = getattr(self._local, "ws_connection", None)
          if conn is None:
              try:
                  conn = create_connection(
                      self.endpoint_uri,
                      timeout=self.timeout
                  )
              except Exception as e:
                  raise ProviderConnectionError(f"Could not connect to {self.endpoint_uri}: {e}") from e
              self._local.ws_connection = conn
          return conn

      def _discard_local_connection(self):
          """Remove the calling thread's cached connection (if any) and close it."""
          conn = getattr(self._local, "ws_connection", None)
          if conn is None:
              return
          del self._local.ws_connection
          try:
              conn.close()
          except Exception:
              # Best effort: the socket is already broken or being abandoned.
              pass

      def _exchange_with_retry(self, request_data):
          """
          Send one encoded request and return the non-empty raw reply.

          Makes at most two attempts. It reconnects and retries once when the socket
          was closed or the reply was empty.

          :raises ProviderConnectionError: on connection failure, on any other
              transport error, or when both attempts fail.
          """
          last_error = None
          for _ in range(2):
              try:
                  conn = self.get_local_connection()
                  conn.send(request_data)
                  response_raw = conn.recv()
              except ProviderConnectionError:
                  # Raised by get_local_connection with a descriptive message already.
                  raise
              except WebSocketConnectionClosedException as e:
                  # The socket was closed: drop it, reconnect and retry.
                  self._discard_local_connection()
                  last_error = e
                  continue
              except Exception as e:
                  # Never reuse the socket after e.g. a recv timeout. The late reply
                  # to this request would be read as the reply to the next one.
                  self._discard_local_connection()
                  raise ProviderConnectionError(f"Error making request to {self.endpoint_uri}: {e}") from e
              if response_raw:
                  return response_raw
              # recv() may return an empty value when the connection was closed
              # unexpectedly. Treat it like a dropped connection.
              self._discard_local_connection()
              last_error = None
          detail = last_error if last_error is not None else "empty response"
          raise ProviderConnectionError(f"Error making request to {self.endpoint_uri}: {detail}") from last_error

      def make_request(self, method: str, params: list) -> dict:
          """
          Send a JSON-RPC request over this thread's connection and decode the reply.

          :raises ProviderConnectionError: on connection or transport failure.
          """
          request_data = self.encode_rpc_request(method, params)
          response_raw = self._exchange_with_retry(request_data)
          # websocket-client returns str for text frames and bytes for binary
          # frames. decode_rpc_response is fed bytes, so only encode str.
          if isinstance(response_raw, str):
              response_raw = response_raw.encode("utf-8")
          return self.decode_rpc_response(response_raw)