|
|
@@ -0,0 +1,390 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+数据库模块
|
|
|
+用于记录价格数据和交易事件,支持后续的数据分析和可视化
|
|
|
+"""
|
|
|
+
|
|
|
+import sqlite3
|
|
|
+import logging
|
|
|
+import os
|
|
|
+from datetime import datetime
|
|
|
+from typing import Optional, List, Dict, Any
|
|
|
+import json
|
|
|
+
|
|
|
+logger = logging.getLogger("database")
|
|
|
+
|
|
|
+class TradingDatabase:
|
|
|
+ """交易数据库类,负责记录价格数据和交易事件"""
|
|
|
+
|
|
|
+ def __init__(self, db_path: str = None):
|
|
|
+ """
|
|
|
+ 初始化数据库连接
|
|
|
+
|
|
|
+ Args:
|
|
|
+ db_path: 数据库文件路径,如果为None则使用默认路径
|
|
|
+ """
|
|
|
+ if db_path is None:
|
|
|
+ # 使用项目根目录下的data文件夹
|
|
|
+ script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
+ project_root = os.path.dirname(os.path.dirname(script_dir))
|
|
|
+ data_dir = os.path.join(project_root, "data")
|
|
|
+
|
|
|
+ # 确保data目录存在
|
|
|
+ os.makedirs(data_dir, exist_ok=True)
|
|
|
+
|
|
|
+ # 使用日期作为数据库文件名
|
|
|
+ today = datetime.now().strftime("%Y%m%d")
|
|
|
+ db_path = os.path.join(data_dir, f"trading_data_{today}.db")
|
|
|
+
|
|
|
+ self.db_path = db_path
|
|
|
+ self.connection = None
|
|
|
+ self._init_database()
|
|
|
+
|
|
|
+ def _init_database(self):
|
|
|
+ """初始化数据库连接和表结构"""
|
|
|
+ try:
|
|
|
+ self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
|
|
|
+ self.connection.row_factory = sqlite3.Row # 使结果可以通过列名访问
|
|
|
+ self._create_tables()
|
|
|
+ logger.info(f"数据库初始化成功: {self.db_path}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"数据库初始化失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def _create_tables(self):
|
|
|
+ """创建数据库表结构"""
|
|
|
+ cursor = self.connection.cursor()
|
|
|
+
|
|
|
+ # 价格数据表
|
|
|
+ cursor.execute("""
|
|
|
+ CREATE TABLE IF NOT EXISTS price_data (
|
|
|
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
|
+ timestamp REAL NOT NULL,
|
|
|
+ datetime_str TEXT NOT NULL,
|
|
|
+ symbol TEXT NOT NULL,
|
|
|
+ lighter_price REAL,
|
|
|
+ binance_price REAL,
|
|
|
+ spread_bps REAL,
|
|
|
+ lighter_bid REAL,
|
|
|
+ lighter_ask REAL,
|
|
|
+ lighter_bid_size REAL,
|
|
|
+ lighter_ask_size REAL,
|
|
|
+ binance_volume REAL,
|
|
|
+ raw_data TEXT,
|
|
|
+ created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
|
+ )
|
|
|
+ """)
|
|
|
+
|
|
|
+ # 交易事件表
|
|
|
+ cursor.execute("""
|
|
|
+ CREATE TABLE IF NOT EXISTS trading_events (
|
|
|
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
|
+ timestamp REAL NOT NULL,
|
|
|
+ datetime_str TEXT NOT NULL,
|
|
|
+ symbol TEXT NOT NULL,
|
|
|
+ event_type TEXT NOT NULL, -- 'open_long', 'open_short', 'close_long', 'close_short'
|
|
|
+ price REAL NOT NULL,
|
|
|
+ quantity REAL NOT NULL,
|
|
|
+ side TEXT NOT NULL, -- 'long', 'short'
|
|
|
+ strategy_state TEXT,
|
|
|
+ spread_bps REAL,
|
|
|
+ lighter_price REAL,
|
|
|
+ binance_price REAL,
|
|
|
+ order_id TEXT,
|
|
|
+ tx_hash TEXT,
|
|
|
+ success BOOLEAN,
|
|
|
+ error_message TEXT,
|
|
|
+ metadata TEXT, -- JSON格式的额外信息
|
|
|
+ created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
|
+ )
|
|
|
+ """)
|
|
|
+
|
|
|
+ # 创建索引以提高查询性能
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_price_timestamp ON price_data(timestamp)")
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_price_symbol ON price_data(symbol)")
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_events_timestamp ON trading_events(timestamp)")
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_events_symbol ON trading_events(symbol)")
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON trading_events(event_type)")
|
|
|
+
|
|
|
+ self.connection.commit()
|
|
|
+ logger.info("数据库表结构创建完成")
|
|
|
+
|
|
|
+ def record_price_data(self,
|
|
|
+ symbol: str,
|
|
|
+ lighter_price: Optional[float] = None,
|
|
|
+ binance_price: Optional[float] = None,
|
|
|
+ spread_bps: Optional[float] = None,
|
|
|
+ lighter_bid: Optional[float] = None,
|
|
|
+ lighter_ask: Optional[float] = None,
|
|
|
+ lighter_bid_size: Optional[float] = None,
|
|
|
+ lighter_ask_size: Optional[float] = None,
|
|
|
+ binance_volume: Optional[float] = None,
|
|
|
+ raw_data: Optional[Dict] = None):
|
|
|
+ """
|
|
|
+ 记录价格数据
|
|
|
+
|
|
|
+ Args:
|
|
|
+ symbol: 交易对符号
|
|
|
+ lighter_price: Lighter价格
|
|
|
+ binance_price: Binance价格
|
|
|
+ spread_bps: 价差(基点)
|
|
|
+ lighter_bid: Lighter买价
|
|
|
+ lighter_ask: Lighter卖价
|
|
|
+ lighter_bid_size: Lighter买量
|
|
|
+ lighter_ask_size: Lighter卖量
|
|
|
+ binance_volume: Binance成交量
|
|
|
+ raw_data: 原始数据(字典格式)
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ timestamp = datetime.now().timestamp()
|
|
|
+ datetime_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
|
|
+
|
|
|
+ cursor = self.connection.cursor()
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO price_data (
|
|
|
+ timestamp, datetime_str, symbol, lighter_price, binance_price,
|
|
|
+ spread_bps, lighter_bid, lighter_ask, lighter_bid_size,
|
|
|
+ lighter_ask_size, binance_volume, raw_data
|
|
|
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
+ """, (
|
|
|
+ timestamp, datetime_str, symbol, lighter_price, binance_price,
|
|
|
+ spread_bps, lighter_bid, lighter_ask, lighter_bid_size,
|
|
|
+ lighter_ask_size, binance_volume,
|
|
|
+ json.dumps(raw_data) if raw_data else None
|
|
|
+ ))
|
|
|
+
|
|
|
+ self.connection.commit()
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"记录价格数据失败: {e}")
|
|
|
+
|
|
|
+ def record_trading_event(self,
|
|
|
+ symbol: str,
|
|
|
+ event_type: str,
|
|
|
+ price: float,
|
|
|
+ quantity: float,
|
|
|
+ side: str,
|
|
|
+ strategy_state: Optional[str] = None,
|
|
|
+ spread_bps: Optional[float] = None,
|
|
|
+ lighter_price: Optional[float] = None,
|
|
|
+ binance_price: Optional[float] = None,
|
|
|
+ order_id: Optional[str] = None,
|
|
|
+ tx_hash: Optional[str] = None,
|
|
|
+ success: Optional[bool] = None,
|
|
|
+ error_message: Optional[str] = None,
|
|
|
+ metadata: Optional[Dict] = None):
|
|
|
+ """
|
|
|
+ 记录交易事件
|
|
|
+
|
|
|
+ Args:
|
|
|
+ symbol: 交易对符号
|
|
|
+ event_type: 事件类型 ('open_long', 'open_short', 'close_long', 'close_short')
|
|
|
+ price: 交易价格
|
|
|
+ quantity: 交易数量
|
|
|
+ side: 交易方向 ('long', 'short')
|
|
|
+ strategy_state: 策略状态
|
|
|
+ spread_bps: 当时的价差
|
|
|
+ lighter_price: 当时的Lighter价格
|
|
|
+ binance_price: 当时的Binance价格
|
|
|
+ order_id: 订单ID
|
|
|
+ tx_hash: 交易哈希
|
|
|
+ success: 是否成功
|
|
|
+ error_message: 错误信息
|
|
|
+ metadata: 额外的元数据
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ timestamp = datetime.now().timestamp()
|
|
|
+ datetime_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
|
|
+
|
|
|
+ cursor = self.connection.cursor()
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO trading_events (
|
|
|
+ timestamp, datetime_str, symbol, event_type, price, quantity,
|
|
|
+ side, strategy_state, spread_bps, lighter_price, binance_price,
|
|
|
+ order_id, tx_hash, success, error_message, metadata
|
|
|
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
+ """, (
|
|
|
+ timestamp, datetime_str, symbol, event_type, price, quantity,
|
|
|
+ side, strategy_state, spread_bps, lighter_price, binance_price,
|
|
|
+ order_id, tx_hash, success, error_message,
|
|
|
+ json.dumps(metadata) if metadata else None
|
|
|
+ ))
|
|
|
+
|
|
|
+ self.connection.commit()
|
|
|
+ logger.info(f"记录交易事件: {event_type} {side} {quantity}@{price}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"记录交易事件失败: {e}")
|
|
|
+
|
|
|
+ def get_price_data(self,
|
|
|
+ symbol: Optional[str] = None,
|
|
|
+ start_time: Optional[float] = None,
|
|
|
+ end_time: Optional[float] = None,
|
|
|
+ limit: Optional[int] = None) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 获取价格数据
|
|
|
+
|
|
|
+ Args:
|
|
|
+ symbol: 交易对符号
|
|
|
+ start_time: 开始时间戳
|
|
|
+ end_time: 结束时间戳
|
|
|
+ limit: 限制返回条数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 价格数据列表
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ cursor = self.connection.cursor()
|
|
|
+
|
|
|
+ query = "SELECT * FROM price_data WHERE 1=1"
|
|
|
+ params = []
|
|
|
+
|
|
|
+ if symbol:
|
|
|
+ query += " AND symbol = ?"
|
|
|
+ params.append(symbol)
|
|
|
+
|
|
|
+ if start_time:
|
|
|
+ query += " AND timestamp >= ?"
|
|
|
+ params.append(start_time)
|
|
|
+
|
|
|
+ if end_time:
|
|
|
+ query += " AND timestamp <= ?"
|
|
|
+ params.append(end_time)
|
|
|
+
|
|
|
+ query += " ORDER BY timestamp DESC"
|
|
|
+
|
|
|
+ if limit:
|
|
|
+ query += " LIMIT ?"
|
|
|
+ params.append(limit)
|
|
|
+
|
|
|
+ cursor.execute(query, params)
|
|
|
+ rows = cursor.fetchall()
|
|
|
+
|
|
|
+ return [dict(row) for row in rows]
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取价格数据失败: {e}")
|
|
|
+ return []
|
|
|
+
|
|
|
+ def get_trading_events(self,
|
|
|
+ symbol: Optional[str] = None,
|
|
|
+ event_type: Optional[str] = None,
|
|
|
+ start_time: Optional[float] = None,
|
|
|
+ end_time: Optional[float] = None,
|
|
|
+ limit: Optional[int] = None) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 获取交易事件
|
|
|
+
|
|
|
+ Args:
|
|
|
+ symbol: 交易对符号
|
|
|
+ event_type: 事件类型
|
|
|
+ start_time: 开始时间戳
|
|
|
+ end_time: 结束时间戳
|
|
|
+ limit: 限制返回条数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 交易事件列表
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ cursor = self.connection.cursor()
|
|
|
+
|
|
|
+ query = "SELECT * FROM trading_events WHERE 1=1"
|
|
|
+ params = []
|
|
|
+
|
|
|
+ if symbol:
|
|
|
+ query += " AND symbol = ?"
|
|
|
+ params.append(symbol)
|
|
|
+
|
|
|
+ if event_type:
|
|
|
+ query += " AND event_type = ?"
|
|
|
+ params.append(event_type)
|
|
|
+
|
|
|
+ if start_time:
|
|
|
+ query += " AND timestamp >= ?"
|
|
|
+ params.append(start_time)
|
|
|
+
|
|
|
+ if end_time:
|
|
|
+ query += " AND timestamp <= ?"
|
|
|
+ params.append(end_time)
|
|
|
+
|
|
|
+ query += " ORDER BY timestamp DESC"
|
|
|
+
|
|
|
+ if limit:
|
|
|
+ query += " LIMIT ?"
|
|
|
+ params.append(limit)
|
|
|
+
|
|
|
+ cursor.execute(query, params)
|
|
|
+ rows = cursor.fetchall()
|
|
|
+
|
|
|
+ return [dict(row) for row in rows]
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取交易事件失败: {e}")
|
|
|
+ return []
|
|
|
+
|
|
|
+ def get_statistics(self, symbol: Optional[str] = None) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 获取统计信息
|
|
|
+
|
|
|
+ Args:
|
|
|
+ symbol: 交易对符号
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 统计信息字典
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ cursor = self.connection.cursor()
|
|
|
+ stats = {}
|
|
|
+
|
|
|
+ # 价格数据统计
|
|
|
+ query = "SELECT COUNT(*) as count FROM price_data"
|
|
|
+ params = []
|
|
|
+ if symbol:
|
|
|
+ query += " WHERE symbol = ?"
|
|
|
+ params.append(symbol)
|
|
|
+
|
|
|
+ cursor.execute(query, params)
|
|
|
+ stats['price_data_count'] = cursor.fetchone()[0]
|
|
|
+
|
|
|
+ # 交易事件统计
|
|
|
+ query = "SELECT COUNT(*) as count FROM trading_events"
|
|
|
+ params = []
|
|
|
+ if symbol:
|
|
|
+ query += " WHERE symbol = ?"
|
|
|
+ params.append(symbol)
|
|
|
+
|
|
|
+ cursor.execute(query, params)
|
|
|
+ stats['trading_events_count'] = cursor.fetchone()[0]
|
|
|
+
|
|
|
+ # 按事件类型统计
|
|
|
+ query = "SELECT event_type, COUNT(*) as count FROM trading_events"
|
|
|
+ params = []
|
|
|
+ if symbol:
|
|
|
+ query += " WHERE symbol = ?"
|
|
|
+ params.append(symbol)
|
|
|
+ query += " GROUP BY event_type"
|
|
|
+
|
|
|
+ cursor.execute(query, params)
|
|
|
+ event_stats = {}
|
|
|
+ for row in cursor.fetchall():
|
|
|
+ event_stats[row[0]] = row[1]
|
|
|
+ stats['events_by_type'] = event_stats
|
|
|
+
|
|
|
+ return stats
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"获取统计信息失败: {e}")
|
|
|
+ return {}
|
|
|
+
|
|
|
+ def close(self):
|
|
|
+ """关闭数据库连接"""
|
|
|
+ if self.connection:
|
|
|
+ self.connection.close()
|
|
|
+ logger.info("数据库连接已关闭")
|
|
|
+
|
|
|
+ def __enter__(self):
|
|
|
+ return self
|
|
|
+
|
|
|
+ def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
+ self.close()
|