| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 测试优化功能的脚本
- """
- import sys
- import os
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'leadlag'))
- from decimal import Decimal
- from sortedcontainers import SortedDict
- import time
- # 测试1: OrderBook 性能测试
- print("=" * 60)
- print("测试1: OrderBook 性能测试")
- print("=" * 60)
- class OrderBookOld:
- """旧版本 - 使用普通dict"""
- def __init__(self):
- self.bids = {}
- self.asks = {}
-
- def update(self, prices):
- for price in prices:
- self.bids[price] = 100
-
- def get_best_bid(self):
- if not self.bids:
- return None
- return max(self.bids.keys())
- class OrderBookNew:
- """新版本 - 使用SortedDict"""
- def __init__(self):
- self.bids = SortedDict(lambda x: -x)
- self.asks = SortedDict()
-
- def update(self, prices):
- for price in prices:
- self.bids[price] = 100
-
- def get_best_bid(self):
- if not self.bids:
- return None
- return self.bids.keys()[0] # SortedDict中第一个就是最大的
- # 生成测试数据
- test_prices = [Decimal(str(100 + i * 0.01)) for i in range(1000)]
- # 测试旧版本
- old_book = OrderBookOld()
- start = time.time()
- for _ in range(100):
- old_book.update(test_prices)
- best = old_book.get_best_bid()
- old_time = time.time() - start
- print(f"旧版本 (普通dict): {old_time:.4f}秒")
- # 测试新版本
- new_book = OrderBookNew()
- start = time.time()
- for _ in range(100):
- new_book.update(test_prices)
- best = new_book.get_best_bid()
- new_time = time.time() - start
- print(f"新版本 (SortedDict): {new_time:.4f}秒")
- print(f"性能提升: {old_time/new_time:.2f}x")
- # 测试2: 数据库批量写入
- print("\n" + "=" * 60)
- print("测试2: 数据库批量写入功能")
- print("=" * 60)
- from database import TradingDatabase
- import tempfile
- # 创建临时数据库
- temp_db = os.path.join(tempfile.gettempdir(), "test_trading.db")
- if os.path.exists(temp_db):
- os.remove(temp_db)
- db = TradingDatabase(temp_db)
- # 测试缓冲区写入
- print("添加100条价格数据到缓冲区...")
- for i in range(100):
- db.record_price_data(
- symbol="TEST",
- binance_price=100.0 + i * 0.01,
- lighter_ask=100.5 + i * 0.01,
- lighter_bid=99.5 + i * 0.01,
- ask_bps=50,
- bid_bps=-50
- )
- print(f"缓冲区中的数据: {len(db._price_data_buffer) if hasattr(db, '_price_data_buffer') else 0} 条")
- # 刷新数据库
- count = db.flush_price_data()
- print(f"批量写入数据库: {count} 条")
- # 验证数据
- data = db.get_price_data(symbol="TEST")
- print(f"数据库中的数据: {len(data)} 条")
- # 清理
- db.close()
- os.remove(temp_db)
- # 测试3: 后台任务初始化
- print("\n" + "=" * 60)
- print("测试3: 后台任务初始化检查")
- print("=" * 60)
- from strategy import TradingStrategy
- from config import load_config
- try:
- config = load_config()
- strategy = TradingStrategy(config)
- # 检查后台任务相关属性
- print(f"✓ 策略初始化成功")
- print(f"✓ 账户更新间隔: {strategy.account_update_interval}秒")
- print(f"✓ 数据库缓冲区已初始化")
- # 检查后台任务方法是否存在
- assert hasattr(strategy, 'start_background_tasks'), "缺少 start_background_tasks 方法"
- assert hasattr(strategy, '_periodic_account_update'), "缺少 _periodic_account_update 方法"
- assert hasattr(strategy, '_periodic_db_flush'), "缺少 _periodic_db_flush 方法"
- print(f"✓ 所有后台任务方法已就位")
- except Exception as e:
- print(f"✗ 错误: {e}")
- # 测试4: OrderBook 缓存机制
- print("\n" + "=" * 60)
- print("测试4: OrderBook 缓存机制")
- print("=" * 60)
- from main import OrderBook
- ob = OrderBook(1)
- # 添加一些数据
- for i in range(50):
- ob.bids[Decimal(str(100 - i * 0.01))] = Decimal(str(10 + i))
- ob.asks[Decimal(str(100 + i * 0.01))] = Decimal(str(10 + i))
- # 第一次调用会排序
- start = time.time()
- bids1 = ob.get_sorted_bids(10)
- time1 = time.time() - start
- print(f"第一次调用 get_sorted_bids: {time1*1000:.4f}ms (包含排序)")
- # 第二次调用会使用缓存
- start = time.time()
- bids2 = ob.get_sorted_bids(10)
- time2 = time.time() - start
- print(f"第二次调用 get_sorted_bids: {time2*1000:.4f}ms (使用缓存)")
- # 更新数据后缓存失效
- ob.update({'bids': [{'price': '99.5', 'size': '100'}]}, 1)
- start = time.time()
- bids3 = ob.get_sorted_bids(10)
- time3 = time.time() - start
- print(f"更新后调用 get_sorted_bids: {time3*1000:.4f}ms (重新排序)")
- print(f"✓ 缓存机制工作正常,缓存加速: {time1/time2:.1f}x")
- print("\n" + "=" * 60)
- print("所有测试完成!")
- print("=" * 60)
|