test_optimizations.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 测试优化功能的脚本
  5. """
  6. import sys
  7. import os
  8. sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'leadlag'))
  9. from decimal import Decimal
  10. from sortedcontainers import SortedDict
  11. import time
  12. # 测试1: OrderBook 性能测试
  13. print("=" * 60)
  14. print("测试1: OrderBook 性能测试")
  15. print("=" * 60)
  16. class OrderBookOld:
  17. """旧版本 - 使用普通dict"""
  18. def __init__(self):
  19. self.bids = {}
  20. self.asks = {}
  21. def update(self, prices):
  22. for price in prices:
  23. self.bids[price] = 100
  24. def get_best_bid(self):
  25. if not self.bids:
  26. return None
  27. return max(self.bids.keys())
  28. class OrderBookNew:
  29. """新版本 - 使用SortedDict"""
  30. def __init__(self):
  31. self.bids = SortedDict(lambda x: -x)
  32. self.asks = SortedDict()
  33. def update(self, prices):
  34. for price in prices:
  35. self.bids[price] = 100
  36. def get_best_bid(self):
  37. if not self.bids:
  38. return None
  39. return self.bids.keys()[0] # SortedDict中第一个就是最大的
  40. # 生成测试数据
  41. test_prices = [Decimal(str(100 + i * 0.01)) for i in range(1000)]
  42. # 测试旧版本
  43. old_book = OrderBookOld()
  44. start = time.time()
  45. for _ in range(100):
  46. old_book.update(test_prices)
  47. best = old_book.get_best_bid()
  48. old_time = time.time() - start
  49. print(f"旧版本 (普通dict): {old_time:.4f}秒")
  50. # 测试新版本
  51. new_book = OrderBookNew()
  52. start = time.time()
  53. for _ in range(100):
  54. new_book.update(test_prices)
  55. best = new_book.get_best_bid()
  56. new_time = time.time() - start
  57. print(f"新版本 (SortedDict): {new_time:.4f}秒")
  58. print(f"性能提升: {old_time/new_time:.2f}x")
  59. # 测试2: 数据库批量写入
  60. print("\n" + "=" * 60)
  61. print("测试2: 数据库批量写入功能")
  62. print("=" * 60)
  63. from database import TradingDatabase
  64. import tempfile
  65. # 创建临时数据库
  66. temp_db = os.path.join(tempfile.gettempdir(), "test_trading.db")
  67. if os.path.exists(temp_db):
  68. os.remove(temp_db)
  69. db = TradingDatabase(temp_db)
  70. # 测试缓冲区写入
  71. print("添加100条价格数据到缓冲区...")
  72. for i in range(100):
  73. db.record_price_data(
  74. symbol="TEST",
  75. binance_price=100.0 + i * 0.01,
  76. lighter_ask=100.5 + i * 0.01,
  77. lighter_bid=99.5 + i * 0.01,
  78. ask_bps=50,
  79. bid_bps=-50
  80. )
  81. print(f"缓冲区中的数据: {len(db._price_data_buffer) if hasattr(db, '_price_data_buffer') else 0} 条")
  82. # 刷新数据库
  83. count = db.flush_price_data()
  84. print(f"批量写入数据库: {count} 条")
  85. # 验证数据
  86. data = db.get_price_data(symbol="TEST")
  87. print(f"数据库中的数据: {len(data)} 条")
  88. # 清理
  89. db.close()
  90. os.remove(temp_db)
  91. # 测试3: 后台任务初始化
  92. print("\n" + "=" * 60)
  93. print("测试3: 后台任务初始化检查")
  94. print("=" * 60)
  95. from strategy import TradingStrategy
  96. from config import load_config
  97. try:
  98. config = load_config()
  99. strategy = TradingStrategy(config)
  100. # 检查后台任务相关属性
  101. print(f"✓ 策略初始化成功")
  102. print(f"✓ 账户更新间隔: {strategy.account_update_interval}秒")
  103. print(f"✓ 数据库缓冲区已初始化")
  104. # 检查后台任务方法是否存在
  105. assert hasattr(strategy, 'start_background_tasks'), "缺少 start_background_tasks 方法"
  106. assert hasattr(strategy, '_periodic_account_update'), "缺少 _periodic_account_update 方法"
  107. assert hasattr(strategy, '_periodic_db_flush'), "缺少 _periodic_db_flush 方法"
  108. print(f"✓ 所有后台任务方法已就位")
  109. except Exception as e:
  110. print(f"✗ 错误: {e}")
  111. # 测试4: OrderBook 缓存机制
  112. print("\n" + "=" * 60)
  113. print("测试4: OrderBook 缓存机制")
  114. print("=" * 60)
  115. from main import OrderBook
  116. ob = OrderBook(1)
  117. # 添加一些数据
  118. for i in range(50):
  119. ob.bids[Decimal(str(100 - i * 0.01))] = Decimal(str(10 + i))
  120. ob.asks[Decimal(str(100 + i * 0.01))] = Decimal(str(10 + i))
  121. # 第一次调用会排序
  122. start = time.time()
  123. bids1 = ob.get_sorted_bids(10)
  124. time1 = time.time() - start
  125. print(f"第一次调用 get_sorted_bids: {time1*1000:.4f}ms (包含排序)")
  126. # 第二次调用会使用缓存
  127. start = time.time()
  128. bids2 = ob.get_sorted_bids(10)
  129. time2 = time.time() - start
  130. print(f"第二次调用 get_sorted_bids: {time2*1000:.4f}ms (使用缓存)")
  131. # 更新数据后缓存失效
  132. ob.update({'bids': [{'price': '99.5', 'size': '100'}]}, 1)
  133. start = time.time()
  134. bids3 = ob.get_sorted_bids(10)
  135. time3 = time.time() - start
  136. print(f"更新后调用 get_sorted_bids: {time3*1000:.4f}ms (重新排序)")
  137. print(f"✓ 缓存机制工作正常,缓存加速: {time1/time2:.1f}x")
  138. print("\n" + "=" * 60)
  139. print("所有测试完成!")
  140. print("=" * 60)