#!/usr/bin/env python # -*- coding: utf-8 -*- """ 测试优化功能的脚本 """ import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'leadlag')) from decimal import Decimal from sortedcontainers import SortedDict import time # 测试1: OrderBook 性能测试 print("=" * 60) print("测试1: OrderBook 性能测试") print("=" * 60) class OrderBookOld: """旧版本 - 使用普通dict""" def __init__(self): self.bids = {} self.asks = {} def update(self, prices): for price in prices: self.bids[price] = 100 def get_best_bid(self): if not self.bids: return None return max(self.bids.keys()) class OrderBookNew: """新版本 - 使用SortedDict""" def __init__(self): self.bids = SortedDict(lambda x: -x) self.asks = SortedDict() def update(self, prices): for price in prices: self.bids[price] = 100 def get_best_bid(self): if not self.bids: return None return self.bids.keys()[0] # SortedDict中第一个就是最大的 # 生成测试数据 test_prices = [Decimal(str(100 + i * 0.01)) for i in range(1000)] # 测试旧版本 old_book = OrderBookOld() start = time.time() for _ in range(100): old_book.update(test_prices) best = old_book.get_best_bid() old_time = time.time() - start print(f"旧版本 (普通dict): {old_time:.4f}秒") # 测试新版本 new_book = OrderBookNew() start = time.time() for _ in range(100): new_book.update(test_prices) best = new_book.get_best_bid() new_time = time.time() - start print(f"新版本 (SortedDict): {new_time:.4f}秒") print(f"性能提升: {old_time/new_time:.2f}x") # 测试2: 数据库批量写入 print("\n" + "=" * 60) print("测试2: 数据库批量写入功能") print("=" * 60) from database import TradingDatabase import tempfile # 创建临时数据库 temp_db = os.path.join(tempfile.gettempdir(), "test_trading.db") if os.path.exists(temp_db): os.remove(temp_db) db = TradingDatabase(temp_db) # 测试缓冲区写入 print("添加100条价格数据到缓冲区...") for i in range(100): db.record_price_data( symbol="TEST", binance_price=100.0 + i * 0.01, lighter_ask=100.5 + i * 0.01, lighter_bid=99.5 + i * 0.01, ask_bps=50, bid_bps=-50 ) print(f"缓冲区中的数据: {len(db._price_data_buffer) if hasattr(db, '_price_data_buffer') else 0} 条") # 刷新数据库 count = db.flush_price_data() print(f"批量写入数据库: {count} 条") # 验证数据 data = db.get_price_data(symbol="TEST") print(f"数据库中的数据: {len(data)} 条") # 清理 db.close() os.remove(temp_db) # 测试3: 后台任务初始化 print("\n" + "=" * 60) print("测试3: 后台任务初始化检查") print("=" * 60) from strategy import TradingStrategy from config import load_config try: config = load_config() strategy = TradingStrategy(config) # 检查后台任务相关属性 print(f"✓ 策略初始化成功") print(f"✓ 账户更新间隔: {strategy.account_update_interval}秒") print(f"✓ 数据库缓冲区已初始化") # 检查后台任务方法是否存在 assert hasattr(strategy, 'start_background_tasks'), "缺少 start_background_tasks 方法" assert hasattr(strategy, '_periodic_account_update'), "缺少 _periodic_account_update 方法" assert hasattr(strategy, '_periodic_db_flush'), "缺少 _periodic_db_flush 方法" print(f"✓ 所有后台任务方法已就位") except Exception as e: print(f"✗ 错误: {e}") # 测试4: OrderBook 缓存机制 print("\n" + "=" * 60) print("测试4: OrderBook 缓存机制") print("=" * 60) from main import OrderBook ob = OrderBook(1) # 添加一些数据 for i in range(50): ob.bids[Decimal(str(100 - i * 0.01))] = Decimal(str(10 + i)) ob.asks[Decimal(str(100 + i * 0.01))] = Decimal(str(10 + i)) # 第一次调用会排序 start = time.time() bids1 = ob.get_sorted_bids(10) time1 = time.time() - start print(f"第一次调用 get_sorted_bids: {time1*1000:.4f}ms (包含排序)") # 第二次调用会使用缓存 start = time.time() bids2 = ob.get_sorted_bids(10) time2 = time.time() - start print(f"第二次调用 get_sorted_bids: {time2*1000:.4f}ms (使用缓存)") # 更新数据后缓存失效 ob.update({'bids': [{'price': '99.5', 'size': '100'}]}, 1) start = time.time() bids3 = ob.get_sorted_bids(10) time3 = time.time() - start print(f"更新后调用 get_sorted_bids: {time3*1000:.4f}ms (重新排序)") print(f"✓ 缓存机制工作正常,缓存加速: {time1/time2:.1f}x") print("\n" + "=" * 60) print("所有测试完成!") print("=" * 60)