瀏覽代碼

web3支持ws以及wss协议。(只不过是同步的)

skyfffire 3 月之前
父節點
當前提交
ef43fe6e99
共有 2 個文件被更改,包括 89 次插入7 次删除
  1. 59 0
      web3_providers.py
  2. 30 7
      web3_py_client.py

+ 59 - 0
web3_providers.py

@@ -0,0 +1,59 @@
+import json
+import threading
+from websocket import create_connection, WebSocketConnectionClosedException
+from web3.providers.base import JSONBaseProvider
+from web3.exceptions import ProviderConnectionError
+
+# 创建一个线程局部存储,确保每个线程有自己的 WebSocket 连接
+thread_local_storage = threading.local()
+
+class SyncWebSocketProvider(JSONBaseProvider):
+    """
+    A custom synchronous WebSocket provider for web3.py v6+.
+    It uses the `websocket-client` library to provide a synchronous interface.
+    """
+    def __init__(self, endpoint_uri: str, timeout: int = 10):
+        self.endpoint_uri = endpoint_uri
+        self.timeout = timeout
+        super().__init__()
+
+    def get_local_connection(self):
+        # 检查当前线程是否有连接,没有则创建
+        if not hasattr(thread_local_storage, 'ws_connection'):
+            try:
+                thread_local_storage.ws_connection = create_connection(
+                    self.endpoint_uri,
+                    timeout=self.timeout
+                )
+            except Exception as e:
+                raise ProviderConnectionError(f"Could not connect to {self.endpoint_uri}: {e}")
+        return thread_local_storage.ws_connection
+
+    def make_request(self, method: str, params: list) -> dict:
+        request_data = self.encode_rpc_request(method, params)
+        
+        try:
+            conn = self.get_local_connection()
+            conn.send(request_data)
+            response_raw = conn.recv()
+            
+            # 如果连接被意外关闭,recv会抛出异常或返回空
+            if not response_raw:
+                 # 清除旧连接并重试一次
+                del thread_local_storage.ws_connection
+                conn = self.get_local_connection()
+                conn.send(request_data)
+                response_raw = conn.recv()
+
+        except WebSocketConnectionClosedException:
+            # 连接已关闭,尝试重新连接并重试一次
+            del thread_local_storage.ws_connection
+            conn = self.get_local_connection()
+            conn.send(request_data)
+            response_raw = conn.recv()
+            
+        except Exception as e:
+            raise ProviderConnectionError(f"Error making request to {self.endpoint_uri}: {e}")
+            
+        response = self.decode_rpc_response(response_raw.encode('utf-8'))
+        return response

+ 30 - 7
web3_py_client.py

@@ -6,11 +6,14 @@ from decimal import Decimal, ROUND_DOWN
 
 from web3 import Web3
 from web3.middleware import ExtraDataToPOAMiddleware # For PoA networks like Goerli, Sepolia, BSC etc.
+from web3.providers import HTTPProvider, WebSocketProvider
+
 from eth_account import Account
 from dotenv import load_dotenv
 from checker.logger_config import get_logger
 from encode_decode import decrypt
 
+from web3_providers import SyncWebSocketProvider
 
 # 配置日志
 logger = get_logger('as')
@@ -135,15 +138,34 @@ class EthClient:
         if not _hash:
             raise ValueError("HASH not provided or found in environment variables.")
 
-        self.w3 = Web3(Web3.HTTPProvider(self.rpc_url))
-
-        # 如果连接的是 PoA 网络 (如 Goerli, Sepolia, BSC, Polygon), 需要注入中间件
-        # 对于主网,不需要此操作。可以根据 chain_id 动态判断,或者让用户明确。
-        # 例如:if self.w3.eth.chain_id in [5, 11155111, 56, 137]: # Goerli, Sepolia, BSC, Polygon
-        self.w3.middleware_onion.inject(ExtraDataToPOAMiddleware, layer=0)
+        self.rpc_url = rpc_url or os.getenv("RPC_URL")
 
+        if not self.rpc_url:
+            raise ValueError("RPC_URL must be provided or set in environment variables.")
+
+        # --- 主要改动在这里 ---
+        if self.rpc_url.startswith("ws://") or self.rpc_url.startswith("wss://"):
+            # 使用我们自定义的同步 WebSocket Provider
+            provider = SyncWebSocketProvider(self.rpc_url)
+            logger.info(f"Using custom SyncWebSocketProvider for {self.rpc_url}")
+        elif self.rpc_url.startswith("http://") or self.rpc_url.startswith("https://"):
+            provider = HTTPProvider(self.rpc_url)
+            logger.info(f"Using HTTPProvider for {self.rpc_url}")
+        else:
+            raise ValueError(f"Invalid RPC URL scheme: {self.rpc_url}.")
+
+        # 使用同步的provider
+        self.w3 = Web3(provider)
+        # 检查是否成功连接
         if not self.w3.is_connected():
-            raise ConnectionError(f"Failed to connect to Ethereum node at {self.rpc_url}")
+            raise ConnectionError(f"Failed to connect to RPC URL: {self.rpc_url}")
+
+        # 注入 PoA 中间件的逻辑保持不变
+        chain_id = self.w3.eth.chain_id
+        poa_chain_ids = {5, 11155111, 56, 97, 137, 80001, 8453, 42161}
+        if chain_id in poa_chain_ids:
+            self.w3.middleware_onion.inject(ExtraDataToPOAMiddleware, layer=0)
+            logger.warning(f"Injected ExtraDataToPOAMiddleware for Chain ID: {chain_id}")
 
         self.account = Account.from_key(_hash)
         self.address = self.account.address
@@ -380,6 +402,7 @@ if __name__ == "__main__":
 
     pprint(ok_chain_client.api_config)
 
+    # client = EthClient('wss://ethereum-rpc.publicnode.com')
     client = EthClient()
 
     CHAIN_ID = 1