Ver código fonte

feat(配置): 重构配置管理模块并集成到交易策略中

将配置管理功能从策略类中分离,创建独立的配置管理模块
策略类现在通过构造函数接收配置对象,提高了模块化和可测试性
添加默认配置值以增强健壮性
skyfffire 5 dias atrás
pai
commit
15e168edaf
3 arquivos alterados com 150 adições e 36 exclusões
  1. 115 0
      src/leadlag/config.py
  2. 11 2
      src/leadlag/listener.py
  3. 24 34
      src/leadlag/strategy.py

+ 115 - 0
src/leadlag/config.py

@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+配置管理模块
+负责加载和管理应用程序配置
+"""
+
+import os
+import toml
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Config:
+    """配置管理类"""
+    
+    def __init__(self, config_path=None):
+        """
+        初始化配置管理器
+        
+        Args:
+            config_path: 配置文件路径,如果为None则使用默认路径
+        """
+        if config_path is None:
+            # 获取项目根目录(向上两级目录)
+            current_dir = os.path.dirname(os.path.abspath(__file__))
+            project_root = os.path.dirname(os.path.dirname(current_dir))
+            config_path = os.path.join(project_root, 'config.toml')
+        
+        self.config_path = config_path
+        self._config_data = None
+        self.load_config()
+    
+    def load_config(self):
+        """加载配置文件"""
+        try:
+            with open(self.config_path, 'r', encoding='utf-8') as f:
+                self._config_data = toml.load(f)
+            logger.info(f"配置文件加载成功: {self.config_path}")
+        except FileNotFoundError:
+            logger.error(f"配置文件未找到: {self.config_path}")
+            raise
+        except Exception as e:
+            logger.error(f"配置文件加载失败: {str(e)}")
+            raise
+    
+    def reload_config(self):
+        """重新加载配置文件"""
+        self.load_config()
+        logger.info("配置文件已重新加载")
+    
+    def get_strategy_config(self):
+        """获取策略配置"""
+        return self._config_data.get('strategy', {})
+    
+    def get_lighter_config(self):
+        """获取Lighter配置"""
+        return self._config_data.get('lighter', {})
+    
+    def get_config(self, section=None):
+        """
+        获取配置数据
+        
+        Args:
+            section: 配置节名称,如果为None则返回完整配置
+            
+        Returns:
+            配置数据字典
+        """
+        if section is None:
+            return self._config_data
+        return self._config_data.get(section, {})
+    
+    def get_value(self, section, key, default=None):
+        """
+        获取指定配置值
+        
+        Args:
+            section: 配置节名称
+            key: 配置键名
+            default: 默认值
+            
+        Returns:
+            配置值
+        """
+        section_data = self._config_data.get(section, {})
+        return section_data.get(key, default)
+
+
+def load_config(config_path=None):
+    """
+    便捷函数:加载配置文件
+    
+    Args:
+        config_path: 配置文件路径
+        
+    Returns:
+        配置数据字典
+    """
+    config_manager = Config(config_path)
+    return config_manager.get_config()
+
+
+def create_config_manager(config_path=None):
+    """
+    便捷函数:创建配置管理器实例
+    
+    Args:
+        config_path: 配置文件路径
+        
+    Returns:
+        Config实例
+    """
+    return Config(config_path)

+ 11 - 2
src/leadlag/listener.py

@@ -18,6 +18,7 @@ import os
 import requests
 from datetime import datetime
 from strategy import TradingStrategy
+from config import load_config
 
 # 配置日志
 # 创建logs目录(如果不存在)
@@ -365,8 +366,16 @@ async def main():
     
     logger.info("正在启动行情数据记录器")
     
-    # 初始化策略
-    trading_strategy = TradingStrategy()
+    # 加载配置文件
+    try:
+        config = load_config()
+        logger.info("配置文件加载成功")
+    except Exception as e:
+        logger.error(f"配置文件加载失败: {str(e)}")
+        return
+    
+    # 初始化策略,传入配置
+    trading_strategy = TradingStrategy(config)
     logger.info("交易策略已初始化")
     
     # 添加每小时打印匹配交易对数量的变量

+ 24 - 34
src/leadlag/strategy.py

@@ -10,7 +10,6 @@ from enum import Enum
 import os
 import lighter
 import time
-import toml
 
 
 # 配置日志
@@ -62,26 +61,6 @@ if not logger.handlers:
 logger.propagate = False
 
 
-def load_config():
-    """加载配置文件"""
-    # 获取项目根目录(向上两级目录)
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    project_root = os.path.dirname(os.path.dirname(current_dir))
-    config_path = os.path.join(project_root, 'config.toml')
-    
-    try:
-        with open(config_path, 'r', encoding='utf-8') as f:
-            config = toml.load(f)
-        logger.info(f"配置文件加载成功: {config_path}")
-        return config
-    except FileNotFoundError:
-        logger.error(f"配置文件未找到: {config_path}")
-        raise
-    except Exception as e:
-        logger.error(f"配置文件加载失败: {str(e)}")
-        raise
-
-
 class StrategyState(Enum):
     """策略状态枚举"""
     WAITING_INIT = 1  # 等待初始化
@@ -97,18 +76,24 @@ class StrategyState(Enum):
 class TradingStrategy:
     """交易策略类"""
     
-    def __init__(self):
-        """初始化策略"""
-        # 加载配置文件
-        self.config = load_config()
+    def __init__(self, config):
+        """
+        初始化策略
+        
+        Args:
+            config: 配置字典,包含strategy和lighter两个部分
+        """
+        # 保存传入的配置
+        self.config = config
         
         self.state = StrategyState.WAITING_INIT
         self.current_position = None    # 当前持仓信息
         
         # 从配置文件读取策略参数
-        self.entry_price_bps = self.config['strategy']['entry_price_bps']        # 入场时的价差(单位:bps)
-        self.target_symbol = self.config['strategy']['target_symbol']     # 目标交易对
-        self.trade_quantity = self.config['strategy']['trade_quantity']         # 交易数量(买卖数量)
+        strategy_config = config.get('strategy', {})
+        self.entry_price_bps = strategy_config.get('entry_price_bps', 20)        # 入场时的价差(单位:bps)
+        self.target_symbol = strategy_config.get('target_symbol', '1000FLOKI')     # 目标交易对
+        self.trade_quantity = strategy_config.get('trade_quantity', 100)         # 交易数量(买卖数量)
         
         self.account_info = None        # 存储账户信息
         self.last_account_update_time = 0  # 上次更新账户信息的时间戳
@@ -116,15 +101,16 @@ class TradingStrategy:
         self.position_side = None       # 持仓方向:'long' 或 'short'
 
         # 从配置文件读取Lighter相关参数
-        self.account_index = self.config['lighter']['account_index']
-        self.api_key_index = self.config['lighter']['api_key_index']
+        lighter_config = config.get('lighter', {})
+        self.account_index = lighter_config.get('account_index', 318163)
+        self.api_key_index = lighter_config.get('api_key_index', 0)
 
         self.api_client = lighter.ApiClient()
         self.account_api = lighter.AccountApi(self.api_client)
         self.transaction_api = lighter.TransactionApi(self.api_client)
         self.signer_client = lighter.SignerClient(  
-            url=self.config['lighter']['url'],  
-            private_key=self.config['lighter']['private_key'],  
+            url=lighter_config.get('url', 'https://mainnet.zklighter.elliot.ai'),  
+            private_key=lighter_config.get('private_key', ''),  
             account_index=self.account_index,
             api_key_index=self.api_key_index
         )
@@ -551,8 +537,12 @@ class TradingStrategy:
             logger.error(f"创建订单时发生错误: {str(e)}")
             return None, str(e)
 
-async def main():    
-    strategy = TradingStrategy()
+async def main():
+    from config import load_config
+    
+    # 加载配置文件
+    config = load_config()
+    strategy = TradingStrategy(config)
     # account = await strategy.account_api.account(by="index", value=f"{strategy.account_index}")
 
     # [AccountPosition(market_id=3, symbol='DOGE', initial_margin_fraction='10.00', open_order_count=0, pending_order_count=0, position_tied_order_count=0, sign=1, position='1', avg_entry_price='0.194368', position_value='0.194360', unrealized_pnl='-0.000008', realized_pnl='0.000000', liquidation_price='0', total_funding_paid_out=None, margin_mode=0, allocated_margin='0.000000', additional_properties={})]