Browse Source

feat(websocket): 添加WebSocket连接管理和消息发送功能

在TradingStrategy类中添加WebSocket连接引用和消息发送方法,用于与交易所进行实时通信
重构listener模块,将Binance数据收集从轮询改为WebSocket订阅模式
优化order book订阅逻辑,仅订阅配置的目标交易对
skyfffire 4 days ago
parent
commit
b9f0bafd1d
2 changed files with 148 additions and 82 deletions
  1. 107 82
      src/leadlag/listener.py
  2. 41 0
      src/leadlag/strategy.py

+ 107 - 82
src/leadlag/listener.py

@@ -15,7 +15,6 @@ import json
 import time
 import logging
 import os
-import requests
 from datetime import datetime
 from strategy import TradingStrategy
 from config import load_config
@@ -77,11 +76,10 @@ logging.getLogger("strategy").setLevel(logging.INFO)
 logger = logging.getLogger("market_data_recorder")
 logger.setLevel(logging.INFO)
 
-# API接口地址
+# API URLs
 LIGHTER_API_URL = "https://mainnet.zklighter.elliot.ai/api/v1/exchangeStats"
 LIGHTER_ORDERBOOKS_URL = "https://mainnet.zklighter.elliot.ai/api/v1/orderBooks"
 LIGHTER_WEBSOCKET_URL = "wss://mainnet.zklighter.elliot.ai/stream"
-BINANCE_TICKER_PRICE_URL = "https://fapi.binance.com/fapi/v2/ticker/price"
 
 # 轮询间隔(秒)
 POLLING_INTERVAL = 1
@@ -107,6 +105,9 @@ binance_data_cache = {
     'latest_prices': {}     # 存储最新价格数据 {symbol: price}
 }
 
+# 全局WebSocket连接,供strategy使用
+lighter_websocket = None
+
 # 全局策略实例
 trading_strategy = None
 
@@ -130,22 +131,6 @@ async def fetch_lighter_orderbooks(session):
         return None
 
 
-async def fetch_binance_ticker_price(session):
-    """从Binance获取最新价格数据"""
-    try:
-        proxy = 'http://' + PROXY_ADDRESS if PROXY_ADDRESS else None
-        async with session.get(BINANCE_TICKER_PRICE_URL, proxy=proxy) as response:
-            if response.status == 200:
-                data = await response.json()
-                return data
-            else:
-                logger.error(f"获取Binance最新价格数据失败: HTTP {response.status}")
-                return None
-    except Exception as e:
-        logger.error(f"获取Binance最新价格数据时出错: {str(e)}")
-        return None
-
-
 def update_market_id_mapping(orderbooks_data):
     """更新market_id到orderbook(市场信息)的映射"""
     global market_id_to_orderbook
@@ -166,58 +151,102 @@ def update_market_id_mapping(orderbooks_data):
         logger.error(f"更新market_id映射时出错: {str(e)}")
 
 
-async def handle_binance_data_collection():
-    """处理Binance数据收集的主循环,每100ms请求一次"""
-    logger.info("开始Binance数据收集任务")
+def get_market_index_by_symbol(target_symbol):
+    """根据配置的币对查找对应的market_index"""
+    for market_id, orderbook in market_id_to_orderbook.items():
+        symbol = orderbook.get('symbol', '')
+        if symbol == target_symbol:
+            return market_id
+    logger.warning(f"未找到币对 {target_symbol} 对应的market_index")
+    return None
+
+
+async def handle_binance_websocket(config):
+    """处理Binance WebSocket连接,订阅aggTrade数据"""
+    
+    # 获取配置的交易对,添加USDT后缀
+    target_symbol = config.get('strategy', 'target_symbol')
+    binance_symbol = f"{target_symbol}USDT"
+    
+    logger.info(f"开始Binance WebSocket连接,订阅 {binance_symbol}@aggTrade")
     
     while True:
         try:
-            async with aiohttp.ClientSession() as session:
-                # 获取最新价格数据
-                price_data = await fetch_binance_ticker_price(session)
-                
-                # 处理最新价格数据
-                if isinstance(price_data, list) and price_data:
-                    for item in price_data:
-                        symbol = item.get('symbol')
-                        price = item.get('price')
-                        if symbol and price:
-                            binance_data_cache['latest_prices'][symbol] = float(price)
-                
-                # 触发策略更新
-                await trigger_strategy_update()
-                
-                # 每100ms请求一次
-                await asyncio.sleep(0.1)
+            # Binance WebSocket URL
+            binance_ws_url = f"wss://fstream.binance.com/ws/{binance_symbol.lower()}@aggTrade"
+            
+            logger.info(f"连接到Binance WebSocket: {binance_ws_url}")
+            async with websockets.connect(binance_ws_url) as websocket:
+                logger.info(f"Binance WebSocket连接成功,开始接收 {binance_symbol} 的aggTrade数据...")
                 
+                async for message in websocket:
+                    try:
+                        data = json.loads(message)
+                        
+                        # 处理aggTrade数据
+                        if data.get("e") == "aggTrade":
+                            symbol = data.get("s")  # 交易对
+                            price = data.get("p")   # 成交价格
+                            
+                            if symbol and price:
+                                binance_data_cache['latest_prices'][symbol] = float(price)
+                                logger.debug(f"更新Binance价格: {symbol} = {price}")
+                                
+                                # 触发策略更新
+                                await trigger_strategy_update()
+                        
+                    except json.JSONDecodeError as e:
+                        logger.error(f"解析Binance WebSocket消息失败: {str(e)}")
+                    except Exception as e:
+                        logger.error(f"处理Binance WebSocket消息时出错: {str(e)}")
+        
+        except websockets.exceptions.ConnectionClosed:
+            logger.warning("Binance WebSocket连接已关闭,尝试重新连接...")
+            await asyncio.sleep(5)
         except Exception as e:
-            import traceback
-            error_info = traceback.format_exc()
-            logger.error(f"Binance数据收集出错: {str(e)}\n{error_info}")
-            await asyncio.sleep(1)  # 出错时等待1秒再重试
+            logger.error(f"Binance WebSocket连接出错: {str(e)}")
+            await asyncio.sleep(5)
 
 
-async def handle_market_stats_websocket():
+async def handle_order_book_websocket(config):
     """
-    处理Lighter Market Stats WebSocket连接
-    订阅所有市场的market_stats数据并更新缓存
+    处理Lighter Order Book WebSocket连接
+    根据配置的币对订阅对应的order_book数据
     """
+    global lighter_websocket
+    
+    # 获取配置的交易对
+    target_symbol = config.get('strategy', 'target_symbol')
+    
     while True:
         try:
             # 构建代理参数
             proxy = 'http://' + PROXY_ADDRESS if PROXY_ADDRESS else None
             
-            logger.info("连接到Lighter Market Stats WebSocket...")
+            logger.info("连接到Lighter Order Book WebSocket...")
             async with websockets.connect(LIGHTER_WEBSOCKET_URL, proxy=proxy) as websocket:
-                # 订阅所有市场的market_stats
+                lighter_websocket = websocket  # 保存WebSocket连接供strategy使用
+                
+                # 将WebSocket连接传递给strategy
+                if trading_strategy:
+                    trading_strategy.set_websocket_connection(websocket)
+                
+                # 根据配置的币对查找market_index
+                market_index = get_market_index_by_symbol(target_symbol)
+                if market_index is None:
+                    logger.error(f"无法找到币对 {target_symbol} 的market_index,无法订阅")
+                    await asyncio.sleep(5)
+                    continue
+                
+                # 订阅指定市场的order_book
                 subscribe_message = {
                     "type": "subscribe",
-                    "channel": "market_stats/all"
+                    "channel": f"order_book/{market_index}"
                 }
                 await websocket.send(json.dumps(subscribe_message))
-                logger.info("已订阅所有市场的Market Stats")
+                logger.info(f"已订阅币对 {target_symbol} (market_index: {market_index}) 的Order Book")
                 
-                logger.info("Market Stats WebSocket连接成功,开始接收数据...")
+                logger.info("Order Book WebSocket连接成功,开始接收数据...")
                 
                 async for message in websocket:
                     try:
@@ -234,31 +263,25 @@ async def handle_market_stats_websocket():
                         elif message_type == "connected":
                             logger.info("WebSocket连接已确认")
                         
-                        # 处理market_stats数据
-                        elif message_type == "update/market_stats" and "market_stats" in data:
-                            market_stats_data = data["market_stats"]
+                        # 处理order_book数据
+                        elif message_type == "update/order_book" and "order_book" in data:
+                            order_book_data = data["order_book"]
+                            channel = data.get("channel", "")
+                            offset = data.get("offset", 0)
                             
-                            # market_stats是一个字典,键是market_id字符串,值是市场数据
-                            for market_id_str, market_data in market_stats_data.items():
-                                market_id = int(market_id_str)  # 将字符串转换为整数
-                                
-                                # 更新缓存,只保留最新数据
-                                market_stats_cache[market_id] = {
-                                    "last_trade_price": market_data.get("last_trade_price"),
-                                    "timestamp": time.time()
-                                }
-                                
-                                symbol = market_id_to_orderbook.get(market_id, {}).get('symbol', f"UNKNOWN_{market_id}")
-                                logger.debug(f"更新Market Stats缓存 - {symbol}(ID:{market_id}): last_trade_price={market_data.get('last_trade_price')}")
+                            # 更新order book缓存
+                            market_stats_cache[market_index] = {
+                                "order_book": order_book_data,
+                                "channel": channel,
+                                "offset": offset,
+                                "timestamp": time.time()
+                            }
                             
-                            # # 触发策略更新
-                            # trigger_strategy_update()
-                        
-                        # 处理订阅确认消息
-                        elif message_type == "subscribed/market_stats":
-                            logger.info("Market Stats订阅确认")
+                            logger.debug(f"收到Order Book更新: market_index={market_index}, offset={offset}")
+                            
+                            # 触发策略更新
+                            await trigger_strategy_update()
                         
-                        # 处理其他未知消息类型
                         else:
                             logger.debug(f"收到未处理的消息类型: {message_type}")
                     
@@ -266,12 +289,14 @@ async def handle_market_stats_websocket():
                         logger.error(f"解析WebSocket消息失败: {str(e)}")
                     except Exception as e:
                         logger.error(f"处理WebSocket消息时出错: {str(e)}")
-                        
+        
         except websockets.exceptions.ConnectionClosed:
-            logger.warning("Market Stats WebSocket连接断开,5秒后重连...")
+            logger.warning("WebSocket连接已关闭,尝试重新连接...")
+            lighter_websocket = None
             await asyncio.sleep(5)
         except Exception as e:
-            logger.error(f"Market Stats WebSocket连接出错: {str(e)},5秒后重连...")
+            logger.error(f"WebSocket连接出错: {str(e)}")
+            lighter_websocket = None
             await asyncio.sleep(5)
 
 
@@ -392,13 +417,13 @@ async def main():
         orderbooks_data = await fetch_lighter_orderbooks(session)
         update_market_id_mapping(orderbooks_data)
     
-    # 启动Market Stats WebSocket任务
-    websocket_task = asyncio.create_task(handle_market_stats_websocket())
-    logger.info("已启动Market Stats WebSocket任务")
+    # 启动Order Book WebSocket任务
+    websocket_task = asyncio.create_task(handle_order_book_websocket(config))
+    logger.info("已启动Order Book WebSocket任务")
     
-    # 启动Binance数据收集任务
-    binance_task = asyncio.create_task(handle_binance_data_collection())
-    logger.info("已启动Binance数据收集任务")
+    # 启动Binance WebSocket任务
+    binance_task = asyncio.create_task(handle_binance_websocket(config))
+    logger.info("已启动Binance WebSocket任务")
     
     # 用于定期刷新日志的计数器
     loop_counter = 0

+ 41 - 0
src/leadlag/strategy.py

@@ -99,6 +99,9 @@ class TradingStrategy:
         self.last_account_update_time = 0  # 上次更新账户信息的时间戳
         self.last_trade_time = 0        # 上次交易时间戳(开仓或平仓)
         self.position_side = None       # 持仓方向:'long' 或 'short'
+        
+        # WebSocket连接引用(由listener设置)
+        self.websocket_connection = None
 
         # 从配置文件读取Lighter相关参数
         lighter_config = config.get('lighter', {})
@@ -537,6 +540,44 @@ class TradingStrategy:
             logger.error(f"创建订单时发生错误: {str(e)}")
             return None, str(e)
 
+    def set_websocket_connection(self, websocket):
+        """设置WebSocket连接引用"""
+        self.websocket_connection = websocket
+        logger.info("WebSocket连接已设置到strategy")
+
+    async def send_websocket_message(self, message):
+        """通过WebSocket发送消息"""
+        if self.websocket_connection is None:
+            logger.warning("WebSocket连接未设置,无法发送消息")
+            return False
+        
+        try:
+            if isinstance(message, dict):
+                message = json.dumps(message)
+            
+            await self.websocket_connection.send(message)
+            logger.debug(f"WebSocket消息发送成功: {message}")
+            return True
+        except Exception as e:
+            logger.error(f"WebSocket消息发送失败: {str(e)}")
+            return False
+
+    async def subscribe_channel(self, channel):
+        """订阅指定频道"""
+        subscribe_message = {
+            "type": "subscribe",
+            "channel": channel
+        }
+        return await self.send_websocket_message(subscribe_message)
+
+    async def unsubscribe_channel(self, channel):
+        """取消订阅指定频道"""
+        unsubscribe_message = {
+            "type": "unsubscribe", 
+            "channel": channel
+        }
+        return await self.send_websocket_message(unsubscribe_message)
+
 async def main():
     from config import load_config