| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 数据库模块
- 用于记录价格数据和交易事件,支持后续的数据分析和可视化
- """
import json
import logging
import os
import sqlite3
import threading
from datetime import datetime
from typing import Any, Dict, List, Optional
logger = logging.getLogger("database")


class TradingDatabase:
    """SQLite store for price snapshots and trading events.

    Rows are written to two tables, ``price_data`` and ``trading_events``,
    for later analysis and visualisation. A single connection is opened with
    ``check_same_thread=False`` so one instance may be shared across threads;
    in that mode the sqlite3 module leaves serialisation of access to the
    caller, so every database operation here runs under ``self._lock``.

    Instances are context managers: leaving the ``with`` block closes the
    connection.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open (creating if needed) the database and ensure the schema exists.

        Args:
            db_path: Path of the SQLite database file. If None, the file
                ``<project_root>/data/trading_data_YYYYMMDD.db`` (today's local
                date) is used, where ``project_root`` is two directories above
                the directory containing this module; the ``data`` directory is
                created if it does not exist.

        Raises:
            Exception: Any error raised while connecting or creating tables is
                logged and re-raised.
        """
        if db_path is None:
            # data/ lives in the project root, two levels above this module's directory.
            script_dir = os.path.dirname(os.path.abspath(__file__))
            project_root = os.path.dirname(os.path.dirname(script_dir))
            data_dir = os.path.join(project_root, "data")
            os.makedirs(data_dir, exist_ok=True)

            # One database file per calendar day.
            today = datetime.now().strftime("%Y%m%d")
            db_path = os.path.join(data_dir, f"trading_data_{today}.db")

        self.db_path = db_path
        self.connection: Optional[sqlite3.Connection] = None
        # Serialises all use of the shared connection (see class docstring).
        self._lock = threading.RLock()
        self._init_database()

    def _init_database(self):
        """Connect to ``self.db_path`` and create the schema.

        On failure the partially opened connection is closed so it is not
        leaked, and the exception is logged and re-raised.
        """
        try:
            self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
            # Rows can be accessed by column name, e.g. row["symbol"].
            self.connection.row_factory = sqlite3.Row
            self._create_tables()
            logger.info(f"数据库初始化成功: {self.db_path}")
        except Exception as e:
            logger.error(f"数据库初始化失败: {e}")
            if self.connection is not None:
                self.connection.close()
                self.connection = None
            raise

    def _create_tables(self):
        """Create the ``price_data`` / ``trading_events`` tables and their indexes if missing."""
        statements = (
            """
            CREATE TABLE IF NOT EXISTS price_data (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp REAL NOT NULL,
                datetime_str TEXT NOT NULL,
                symbol TEXT NOT NULL,
                lighter_price REAL,
                binance_price REAL,
                spread_bps REAL,
                lighter_bid REAL,
                lighter_ask REAL,
                lighter_bid_size REAL,
                lighter_ask_size REAL,
                binance_volume REAL,
                raw_data TEXT,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS trading_events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp REAL NOT NULL,
                datetime_str TEXT NOT NULL,
                symbol TEXT NOT NULL,
                event_type TEXT NOT NULL, -- 'open_long', 'open_short', 'close_long', 'close_short'
                price REAL NOT NULL,
                quantity REAL NOT NULL,
                side TEXT NOT NULL, -- 'long', 'short'
                strategy_state TEXT,
                spread_bps REAL,
                lighter_price REAL,
                binance_price REAL,
                order_id TEXT,
                tx_hash TEXT,
                success BOOLEAN,
                error_message TEXT,
                metadata TEXT, -- JSON格式的额外信息
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
            """,
            # Indexes for the timestamp / symbol / event_type filters used by the getters.
            "CREATE INDEX IF NOT EXISTS idx_price_timestamp ON price_data(timestamp)",
            "CREATE INDEX IF NOT EXISTS idx_price_symbol ON price_data(symbol)",
            "CREATE INDEX IF NOT EXISTS idx_events_timestamp ON trading_events(timestamp)",
            "CREATE INDEX IF NOT EXISTS idx_events_symbol ON trading_events(symbol)",
            "CREATE INDEX IF NOT EXISTS idx_events_type ON trading_events(event_type)",
        )
        with self._lock:
            conn = self._require_connection()
            # Connection context manager: commit on success, roll back on error.
            with conn:
                for statement in statements:
                    conn.execute(statement)
        logger.info("数据库表结构创建完成")

    def _require_connection(self) -> sqlite3.Connection:
        """Return the open connection.

        Raises:
            sqlite3.ProgrammingError: If :meth:`close` has already been called.
        """
        if self.connection is None:
            raise sqlite3.ProgrammingError("Cannot operate on a closed database.")
        return self.connection

    @staticmethod
    def _now_stamp() -> tuple:
        """Return ``(epoch_seconds, "YYYY-MM-DD HH:MM:SS.mmm")`` for one local-time instant.

        Both values come from a single ``datetime.now()`` call, so a row's
        numeric and textual timestamps always describe the same moment.
        """
        now = datetime.now()
        return now.timestamp(), now.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

    @staticmethod
    def _where_clause(conditions: List[tuple]) -> tuple:
        """Combine ``(sql_fragment, value)`` pairs into a parameterised WHERE clause.

        Returns:
            ``(" WHERE a = ? AND b >= ?", [value_a, value_b])``, or ``("", [])``
            when *conditions* is empty. The params list is freshly built, so
            callers may append to it.
        """
        if not conditions:
            return "", []
        fragments, params = zip(*conditions)
        return " WHERE " + " AND ".join(fragments), list(params)

    def _select_recent(self, table: str, conditions: List[tuple],
                       limit: Optional[int]) -> List[Dict]:
        """Return rows of *table* matching *conditions*, newest ``timestamp`` first.

        *table* is interpolated into the SQL, so it must be one of this class's
        own table names, never caller-supplied text. Values are always bound
        as parameters.
        """
        where, params = self._where_clause(conditions)
        query = f"SELECT * FROM {table}{where} ORDER BY timestamp DESC"
        if limit is not None:
            query += " LIMIT ?"
            params.append(limit)
        with self._lock:
            rows = self._require_connection().execute(query, params).fetchall()
        return [dict(row) for row in rows]

    def record_price_data(self,
                          symbol: str,
                          lighter_price: Optional[float] = None,
                          binance_price: Optional[float] = None,
                          spread_bps: Optional[float] = None,
                          lighter_bid: Optional[float] = None,
                          lighter_ask: Optional[float] = None,
                          lighter_bid_size: Optional[float] = None,
                          lighter_ask_size: Optional[float] = None,
                          binance_volume: Optional[float] = None,
                          raw_data: Optional[Dict] = None):
        """Insert one row into ``price_data``, stamped with the current local time.

        Errors are logged and swallowed (best-effort recording); a failed
        insert is rolled back so no transaction is left open.

        Args:
            symbol: Trading pair symbol (required; NOT NULL column).
            lighter_price: Lighter price.
            binance_price: Binance price.
            spread_bps: Spread in basis points.
            lighter_bid: Lighter best bid.
            lighter_ask: Lighter best ask.
            lighter_bid_size: Lighter bid size.
            lighter_ask_size: Lighter ask size.
            binance_volume: Binance volume.
            raw_data: Raw payload; stored as JSON text when truthy, else NULL.
        """
        try:
            timestamp, datetime_str = self._now_stamp()
            with self._lock:
                conn = self._require_connection()
                # Commit on success, roll back on error.
                with conn:
                    conn.execute("""
                        INSERT INTO price_data (
                            timestamp, datetime_str, symbol, lighter_price, binance_price,
                            spread_bps, lighter_bid, lighter_ask, lighter_bid_size,
                            lighter_ask_size, binance_volume, raw_data
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        timestamp, datetime_str, symbol, lighter_price, binance_price,
                        spread_bps, lighter_bid, lighter_ask, lighter_bid_size,
                        lighter_ask_size, binance_volume,
                        json.dumps(raw_data) if raw_data else None,
                    ))
        except Exception as e:
            logger.error(f"记录价格数据失败: {e}")

    def record_trading_event(self,
                             symbol: str,
                             event_type: str,
                             price: float,
                             quantity: float,
                             side: str,
                             strategy_state: Optional[str] = None,
                             spread_bps: Optional[float] = None,
                             lighter_price: Optional[float] = None,
                             binance_price: Optional[float] = None,
                             order_id: Optional[str] = None,
                             tx_hash: Optional[str] = None,
                             success: Optional[bool] = None,
                             error_message: Optional[str] = None,
                             metadata: Optional[Dict] = None):
        """Insert one row into ``trading_events``, stamped with the current local time.

        Errors are logged and swallowed (best-effort recording); a failed
        insert is rolled back so no transaction is left open.

        Args:
            symbol: Trading pair symbol.
            event_type: Event type; the schema comment lists 'open_long',
                'open_short', 'close_long', 'close_short' (not enforced).
            price: Trade price.
            quantity: Trade quantity.
            side: Trade side; the schema comment lists 'long', 'short'.
            strategy_state: Strategy state at the time of the event.
            spread_bps: Spread at the time of the event.
            lighter_price: Lighter price at the time of the event.
            binance_price: Binance price at the time of the event.
            order_id: Order ID.
            tx_hash: Transaction hash.
            success: Whether the action succeeded (stored by sqlite3 as 1/0).
            error_message: Error message, if any.
            metadata: Extra metadata; stored as JSON text when truthy, else NULL.
        """
        try:
            timestamp, datetime_str = self._now_stamp()
            with self._lock:
                conn = self._require_connection()
                # Commit on success, roll back on error.
                with conn:
                    conn.execute("""
                        INSERT INTO trading_events (
                            timestamp, datetime_str, symbol, event_type, price, quantity,
                            side, strategy_state, spread_bps, lighter_price, binance_price,
                            order_id, tx_hash, success, error_message, metadata
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        timestamp, datetime_str, symbol, event_type, price, quantity,
                        side, strategy_state, spread_bps, lighter_price, binance_price,
                        order_id, tx_hash, success, error_message,
                        json.dumps(metadata) if metadata else None,
                    ))
            logger.info(f"记录交易事件: {event_type} {side} {quantity}@{price}")
        except Exception as e:
            logger.error(f"记录交易事件失败: {e}")

    def get_price_data(self,
                       symbol: Optional[str] = None,
                       start_time: Optional[float] = None,
                       end_time: Optional[float] = None,
                       limit: Optional[int] = None) -> List[Dict]:
        """Return ``price_data`` rows as dicts, newest first.

        Args:
            symbol: Only rows for this symbol (ignored when empty/None).
            start_time: Inclusive lower bound on ``timestamp`` (epoch seconds).
            end_time: Inclusive upper bound on ``timestamp`` (epoch seconds).
            limit: Maximum number of rows; ``0`` returns no rows.

        Returns:
            List of row dicts; an empty list if the query fails (error logged).
        """
        conditions = []
        if symbol:
            conditions.append(("symbol = ?", symbol))
        # `is not None`: 0 is a legitimate bound / limit, not "no filter".
        if start_time is not None:
            conditions.append(("timestamp >= ?", start_time))
        if end_time is not None:
            conditions.append(("timestamp <= ?", end_time))
        try:
            return self._select_recent("price_data", conditions, limit)
        except Exception as e:
            logger.error(f"获取价格数据失败: {e}")
            return []

    def get_trading_events(self,
                           symbol: Optional[str] = None,
                           event_type: Optional[str] = None,
                           start_time: Optional[float] = None,
                           end_time: Optional[float] = None,
                           limit: Optional[int] = None) -> List[Dict]:
        """Return ``trading_events`` rows as dicts, newest first.

        Args:
            symbol: Only rows for this symbol (ignored when empty/None).
            event_type: Only rows of this event type (ignored when empty/None).
            start_time: Inclusive lower bound on ``timestamp`` (epoch seconds).
            end_time: Inclusive upper bound on ``timestamp`` (epoch seconds).
            limit: Maximum number of rows; ``0`` returns no rows.

        Returns:
            List of row dicts; an empty list if the query fails (error logged).
        """
        conditions = []
        if symbol:
            conditions.append(("symbol = ?", symbol))
        if event_type:
            conditions.append(("event_type = ?", event_type))
        # `is not None`: 0 is a legitimate bound / limit, not "no filter".
        if start_time is not None:
            conditions.append(("timestamp >= ?", start_time))
        if end_time is not None:
            conditions.append(("timestamp <= ?", end_time))
        try:
            return self._select_recent("trading_events", conditions, limit)
        except Exception as e:
            logger.error(f"获取交易事件失败: {e}")
            return []

    def get_statistics(self, symbol: Optional[str] = None) -> Dict[str, Any]:
        """Return row counts, optionally restricted to one symbol.

        Args:
            symbol: Only count rows for this symbol (ignored when empty/None).

        Returns:
            ``{'price_data_count': int, 'trading_events_count': int,
            'events_by_type': {event_type: count}}``; an empty dict if a
            query fails (error logged).
        """
        try:
            where, params = self._where_clause([("symbol = ?", symbol)] if symbol else [])
            with self._lock:
                conn = self._require_connection()
                price_count = conn.execute(
                    f"SELECT COUNT(*) as count FROM price_data{where}", params
                ).fetchone()[0]
                event_count = conn.execute(
                    f"SELECT COUNT(*) as count FROM trading_events{where}", params
                ).fetchone()[0]
                by_type_rows = conn.execute(
                    f"SELECT event_type, COUNT(*) as count FROM trading_events{where}"
                    " GROUP BY event_type",
                    params,
                ).fetchall()
            return {
                'price_data_count': price_count,
                'trading_events_count': event_count,
                'events_by_type': {row[0]: row[1] for row in by_type_rows},
            }
        except Exception as e:
            logger.error(f"获取统计信息失败: {e}")
            return {}

    def close(self):
        """Close the database connection. Safe to call more than once."""
        with self._lock:
            if self.connection is not None:
                self.connection.close()
                # Later operations fail cleanly via _require_connection().
                self.connection = None
                logger.info("数据库连接已关闭")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|