database.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 数据库模块
  5. 用于记录价格数据和交易事件,支持后续的数据分析和可视化
  6. """
  7. import sqlite3
  8. import logging
  9. import os
  10. from datetime import datetime
  11. from typing import Optional, List, Dict, Any
  12. import json
  # Module-level logger used by TradingDatabase for status and error messages.
  # It is registered under the fixed name "database" (not __name__), so any
  # handler/level configuration must target that name.
  logger = logging.getLogger("database")
  class TradingDatabase:
      """SQLite-backed store for price snapshots and trading events.

      The recorded rows are intended for later analysis and visualisation.
      The class can be used as a context manager; the connection is closed
      when the ``with`` block exits.
      """

      def __init__(self, db_path: Optional[str] = None):
          """Open the database connection and make sure the schema exists.

          Args:
              db_path: Path of the SQLite database file (``":memory:"`` is also
                  accepted by sqlite3). When None, ``trading_data_YYYYMMDD.db``
                  (today's local date) inside the project's ``data`` directory
                  is used, and that directory is created if missing.

          Raises:
              Exception: any error raised while opening the database or creating
                  the tables is logged and re-raised.
          """
          if db_path is None:
              db_path = self._default_db_path()
          self.db_path = db_path
          self.connection: Optional[sqlite3.Connection] = None
          self._init_database()

      @staticmethod
      def _default_db_path() -> str:
          """Return today's default database path, creating the data directory."""
          # The project root is taken to be two directories above this module's directory.
          script_dir = os.path.dirname(os.path.abspath(__file__))
          project_root = os.path.dirname(os.path.dirname(script_dir))
          data_dir = os.path.join(project_root, "data")
          os.makedirs(data_dir, exist_ok=True)
          # One database file per calendar day.
          today = datetime.now().strftime("%Y%m%d")
          return os.path.join(data_dir, f"trading_data_{today}.db")

      def _init_database(self):
          """Open the connection and create the tables.

          If any step fails, the partially opened connection is closed and
          ``self.connection`` is reset to None before the error is re-raised,
          so no file handle is leaked.
          """
          try:
              # NOTE(review): check_same_thread=False allows other threads to use
              # this connection, but this class adds no locking of its own --
              # confirm that callers serialise access.
              self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
              # Rows become addressable by column name and convertible with dict().
              self.connection.row_factory = sqlite3.Row
              self._create_tables()
              logger.info(f"数据库初始化成功: {self.db_path}")
          except Exception as e:
              logger.error(f"数据库初始化失败: {e}")
              if self.connection is not None:
                  self.connection.close()
                  self.connection = None
              raise

      def _create_tables(self):
          """Create the price_data / trading_events tables and their indexes if absent."""
          cursor = self.connection.cursor()
          # Price snapshots.
          cursor.execute("""
              CREATE TABLE IF NOT EXISTS price_data (
                  id INTEGER PRIMARY KEY AUTOINCREMENT,
                  timestamp REAL NOT NULL,
                  datetime_str TEXT NOT NULL,
                  symbol TEXT NOT NULL,
                  lighter_price REAL,
                  binance_price REAL,
                  spread_bps REAL,
                  lighter_bid REAL,
                  lighter_ask REAL,
                  lighter_bid_size REAL,
                  lighter_ask_size REAL,
                  binance_volume REAL,
                  raw_data TEXT,
                  created_at DATETIME DEFAULT CURRENT_TIMESTAMP
              )
          """)
          # Trading events (opens/closes, successful or failed).
          cursor.execute("""
              CREATE TABLE IF NOT EXISTS trading_events (
                  id INTEGER PRIMARY KEY AUTOINCREMENT,
                  timestamp REAL NOT NULL,
                  datetime_str TEXT NOT NULL,
                  symbol TEXT NOT NULL,
                  event_type TEXT NOT NULL, -- 'open_long', 'open_short', 'close_long', 'close_short'
                  price REAL NOT NULL,
                  quantity REAL NOT NULL,
                  side TEXT NOT NULL, -- 'long', 'short'
                  strategy_state TEXT,
                  spread_bps REAL,
                  lighter_price REAL,
                  binance_price REAL,
                  order_id TEXT,
                  tx_hash TEXT,
                  success BOOLEAN,
                  error_message TEXT,
                  metadata TEXT, -- JSON格式的额外信息
                  created_at DATETIME DEFAULT CURRENT_TIMESTAMP
              )
          """)
          # Indexes for the columns the query methods filter and sort on.
          index_statements = (
              "CREATE INDEX IF NOT EXISTS idx_price_timestamp ON price_data(timestamp)",
              "CREATE INDEX IF NOT EXISTS idx_price_symbol ON price_data(symbol)",
              "CREATE INDEX IF NOT EXISTS idx_events_timestamp ON trading_events(timestamp)",
              "CREATE INDEX IF NOT EXISTS idx_events_symbol ON trading_events(symbol)",
              "CREATE INDEX IF NOT EXISTS idx_events_type ON trading_events(event_type)",
          )
          for statement in index_statements:
              cursor.execute(statement)
          self.connection.commit()
          logger.info("数据库表结构创建完成")

      @staticmethod
      def _current_time() -> tuple:
          """Return ``(epoch_seconds, "YYYY-mm-dd HH:MM:SS.mmm")`` for ONE instant.

          A single ``datetime.now()`` call feeds both values, so the numeric
          timestamp and the human-readable string always agree.
          """
          now = datetime.now()
          # %f gives microseconds; dropping the last 3 digits leaves milliseconds.
          return now.timestamp(), now.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

      @staticmethod
      def _where_clause(symbol: Optional[str] = None,
                        event_type: Optional[str] = None,
                        start_time: Optional[float] = None,
                        end_time: Optional[float] = None) -> tuple:
          """Translate optional filters into ``(" WHERE ...", params)``.

          Empty/None ``symbol`` and ``event_type`` mean "no filter". Time bounds
          are applied whenever they are not None, so a bound of 0 is honoured.
          Returns ``("", [])`` when no filter applies.
          """
          conditions: List[str] = []
          params: List[Any] = []
          if symbol:
              conditions.append("symbol = ?")
              params.append(symbol)
          if event_type:
              conditions.append("event_type = ?")
              params.append(event_type)
          if start_time is not None:
              conditions.append("timestamp >= ?")
              params.append(start_time)
          if end_time is not None:
              conditions.append("timestamp <= ?")
              params.append(end_time)
          where = " WHERE " + " AND ".join(conditions) if conditions else ""
          return where, params

      def _fetch_rows(self,
                      table: str,
                      limit: Optional[int] = None,
                      **filters: Any) -> List[Dict]:
          """Run a filtered ``SELECT *`` on ``table``, newest first, as dicts.

          ``table`` is only ever one of this class's own table names (never
          caller input); ``filters`` are passed to ``_where_clause`` and bound
          as SQL parameters.
          """
          where, params = self._where_clause(**filters)
          query = f"SELECT * FROM {table}{where} ORDER BY timestamp DESC"
          if limit is not None:
              # limit=0 must yield no rows, not "no limit".
              query += " LIMIT ?"
              params.append(limit)
          cursor = self.connection.cursor()
          cursor.execute(query, params)
          return [dict(row) for row in cursor.fetchall()]

      def record_price_data(self,
                            symbol: str,
                            lighter_price: Optional[float] = None,
                            binance_price: Optional[float] = None,
                            spread_bps: Optional[float] = None,
                            lighter_bid: Optional[float] = None,
                            lighter_ask: Optional[float] = None,
                            lighter_bid_size: Optional[float] = None,
                            lighter_ask_size: Optional[float] = None,
                            binance_volume: Optional[float] = None,
                            raw_data: Optional[Dict] = None):
          """Insert one price snapshot stamped with the current local time.

          Args:
              symbol: Trading pair symbol.
              lighter_price: Lighter price.
              binance_price: Binance price.
              spread_bps: Spread in basis points.
              lighter_bid: Lighter best bid.
              lighter_ask: Lighter best ask.
              lighter_bid_size: Lighter bid size.
              lighter_ask_size: Lighter ask size.
              binance_volume: Binance volume.
              raw_data: Raw payload; stored as JSON text (a falsy value stores NULL).

          Errors are logged and swallowed: the row is dropped and nothing is raised.
          """
          try:
              timestamp, datetime_str = self._current_time()
              values = (
                  timestamp, datetime_str, symbol, lighter_price, binance_price,
                  spread_bps, lighter_bid, lighter_ask, lighter_bid_size,
                  lighter_ask_size, binance_volume,
                  json.dumps(raw_data) if raw_data else None,
              )
              cursor = self.connection.cursor()
              cursor.execute("""
                  INSERT INTO price_data (
                      timestamp, datetime_str, symbol, lighter_price, binance_price,
                      spread_bps, lighter_bid, lighter_ask, lighter_bid_size,
                      lighter_ask_size, binance_volume, raw_data
                  ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
              """, values)
              self.connection.commit()
          except Exception as e:
              logger.error(f"记录价格数据失败: {e}")

      def record_trading_event(self,
                               symbol: str,
                               event_type: str,
                               price: float,
                               quantity: float,
                               side: str,
                               strategy_state: Optional[str] = None,
                               spread_bps: Optional[float] = None,
                               lighter_price: Optional[float] = None,
                               binance_price: Optional[float] = None,
                               order_id: Optional[str] = None,
                               tx_hash: Optional[str] = None,
                               success: Optional[bool] = None,
                               error_message: Optional[str] = None,
                               metadata: Optional[Dict] = None):
          """Insert one trading event stamped with the current local time.

          Args:
              symbol: Trading pair symbol.
              event_type: Event type ('open_long', 'open_short', 'close_long', 'close_short').
              price: Trade price (the column is NOT NULL).
              quantity: Trade quantity (the column is NOT NULL).
              side: Trade side ('long', 'short'; the column is NOT NULL).
              strategy_state: Strategy state at the time of the event.
              spread_bps: Spread at the time of the event.
              lighter_price: Lighter price at the time of the event.
              binance_price: Binance price at the time of the event.
              order_id: Order ID.
              tx_hash: Transaction hash.
              success: Whether the operation succeeded.
              error_message: Error message, if any.
              metadata: Extra metadata; stored as JSON text (a falsy value stores NULL).

          Errors (including NOT NULL violations) are logged and swallowed.
          """
          try:
              timestamp, datetime_str = self._current_time()
              values = (
                  timestamp, datetime_str, symbol, event_type, price, quantity,
                  side, strategy_state, spread_bps, lighter_price, binance_price,
                  order_id, tx_hash, success, error_message,
                  json.dumps(metadata) if metadata else None,
              )
              cursor = self.connection.cursor()
              cursor.execute("""
                  INSERT INTO trading_events (
                      timestamp, datetime_str, symbol, event_type, price, quantity,
                      side, strategy_state, spread_bps, lighter_price, binance_price,
                      order_id, tx_hash, success, error_message, metadata
                  ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
              """, values)
              self.connection.commit()
              logger.info(f"记录交易事件: {event_type} {side} {quantity}@{price}")
          except Exception as e:
              logger.error(f"记录交易事件失败: {e}")

      def get_price_data(self,
                         symbol: Optional[str] = None,
                         start_time: Optional[float] = None,
                         end_time: Optional[float] = None,
                         limit: Optional[int] = None) -> List[Dict]:
          """Return price snapshots, newest first.

          Args:
              symbol: Only rows for this symbol (empty/None = all symbols).
              start_time: Inclusive lower bound on the epoch timestamp.
              end_time: Inclusive upper bound on the epoch timestamp.
              limit: Maximum number of rows (None = unlimited; 0 = none).

          Returns:
              A list of dicts keyed by column name; ``[]`` on error (the error is logged).
          """
          try:
              return self._fetch_rows("price_data", limit=limit, symbol=symbol,
                                      start_time=start_time, end_time=end_time)
          except Exception as e:
              logger.error(f"获取价格数据失败: {e}")
              return []

      def get_trading_events(self,
                             symbol: Optional[str] = None,
                             event_type: Optional[str] = None,
                             start_time: Optional[float] = None,
                             end_time: Optional[float] = None,
                             limit: Optional[int] = None) -> List[Dict]:
          """Return trading events, newest first.

          Args:
              symbol: Only events for this symbol (empty/None = all symbols).
              event_type: Only events of this type (empty/None = all types).
              start_time: Inclusive lower bound on the epoch timestamp.
              end_time: Inclusive upper bound on the epoch timestamp.
              limit: Maximum number of rows (None = unlimited; 0 = none).

          Returns:
              A list of dicts keyed by column name; ``[]`` on error (the error is logged).
          """
          try:
              return self._fetch_rows("trading_events", limit=limit, symbol=symbol,
                                      event_type=event_type, start_time=start_time,
                                      end_time=end_time)
          except Exception as e:
              logger.error(f"获取交易事件失败: {e}")
              return []

      def get_statistics(self, symbol: Optional[str] = None) -> Dict[str, Any]:
          """Return row counts, optionally restricted to one symbol.

          Args:
              symbol: Only count rows for this symbol (empty/None = all symbols).

          Returns:
              ``{'price_data_count': int, 'trading_events_count': int,
              'events_by_type': {event_type: count}}``, or ``{}`` on error
              (the error is logged).
          """
          try:
              where, params = self._where_clause(symbol=symbol)
              cursor = self.connection.cursor()

              cursor.execute(f"SELECT COUNT(*) as count FROM price_data{where}", params)
              price_data_count = cursor.fetchone()[0]

              cursor.execute(f"SELECT COUNT(*) as count FROM trading_events{where}", params)
              trading_events_count = cursor.fetchone()[0]

              cursor.execute(
                  f"SELECT event_type, COUNT(*) as count FROM trading_events{where}"
                  " GROUP BY event_type",
                  params,
              )
              events_by_type = {row[0]: row[1] for row in cursor.fetchall()}

              return {
                  'price_data_count': price_data_count,
                  'trading_events_count': trading_events_count,
                  'events_by_type': events_by_type,
              }
          except Exception as e:
              logger.error(f"获取统计信息失败: {e}")
              return {}

      def close(self):
          """Close the database connection; further calls are no-ops."""
          if self.connection is not None:
              self.connection.close()
              self.connection = None
              logger.info("数据库连接已关闭")

      def __enter__(self):
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          # Always close; returning None lets any exception propagate.
          self.close()