# data_processing.py (4.0 KB)
  1. import json
  2. import numpy as np
  3. import pandas as pd
  4. import time
  5. import queue
  6. import threading
  7. from datetime import datetime
  8. from logger_config import logger
  9. from collections import deque
  10. from sklearn.linear_model import LogisticRegression
  11. # 假设我们有一个数据流,订单簿和成交数据
  12. order_book_snapshots = deque(maxlen=10000) # 存储过去100ms的订单簿快照
  13. trade_data = deque(maxlen=10000) # 存储过去100ms的成交数据
  14. # 数据积累的阈值
  15. DATA_THRESHOLD = 10
  16. model = LogisticRegression() # 初始化模型
  17. model_trained = False # 标记模型是否已经被训练
  18. messages = queue.Queue() # 创建一个线程安全队列
  19. stop_event = threading.Event()
  20. def on_message_trade(_ws, message):
  21. global trade_data
  22. json_message = json.loads(message)
  23. trade = {
  24. 'price': float(json_message['data']['p']),
  25. 'qty': float(json_message['data']['q']),
  26. 'timestamp': pd.to_datetime(json_message['data']['T'], unit='ms'),
  27. 'side': 'sell' if json_message['data']['m'] else 'buy'
  28. }
  29. trade_data.append(trade)
  30. predict_market_direction()
  31. check_and_train_model()
  32. def on_message_depth(_ws, message):
  33. global order_book_snapshots
  34. json_message = json.loads(message)
  35. bids = json_message['data']['b'][:10] # Top 10 bids
  36. asks = json_message['data']['a'][:10] # Top 10 asks
  37. timestamp = pd.to_datetime(json_message['data']['E'], unit='ms')
  38. order_book_snapshots.append({
  39. 'bids': bids,
  40. 'asks': asks,
  41. 'timestamp': timestamp
  42. })
  43. predict_market_direction()
  44. check_and_train_model()
  45. def extract_features(order_book, trade):
  46. # 计算买卖盘差距(spread)
  47. best_bid = float(order_book['bids'][0][0])
  48. best_ask = float(order_book['asks'][0][0])
  49. spread = best_ask - best_bid
  50. # 计算买卖盘深度
  51. bid_depth = sum(float(bid[1]) for bid in order_book['bids'])
  52. ask_depth = sum(float(ask[1]) for ask in order_book['asks'])
  53. # 计算成交量和方向
  54. trade_volume = trade['qty']
  55. trade_side = 1 if trade['side'] == 'buy' else -1
  56. features = {
  57. 'spread': spread,
  58. 'bid_depth': bid_depth,
  59. 'ask_depth': ask_depth,
  60. 'trade_volume': trade_volume,
  61. 'trade_side': trade_side
  62. }
  63. return features
  64. def prepare_training_data():
  65. # 提取特征和标签
  66. X = []
  67. y = []
  68. for i in range(len(order_book_snapshots) - 1):
  69. current_order_book = order_book_snapshots[i]
  70. current_trade = trade_data[i]
  71. future_order_book = order_book_snapshots[i + 1]
  72. # 提取当前的特征
  73. features = extract_features(current_order_book, current_trade)
  74. X.append(list(features.values()))
  75. # 生成标签
  76. current_price = float(current_order_book['bids'][0][0])
  77. future_price = float(future_order_book['bids'][0][0])
  78. label = generate_label(current_price, future_price)
  79. y.append(label)
  80. # 将特征和标签转换为NumPy数组
  81. X_train = np.array(X)
  82. y_train = np.array(y)
  83. return X_train, y_train
  84. def generate_label(current_price, future_price):
  85. return 1 if future_price > current_price else 0
  86. def check_and_train_model():
  87. global model_trained
  88. if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
  89. X_train, y_train = prepare_training_data()
  90. model.fit(X_train, y_train)
  91. model_trained = True # 标记模型已经被训练
  92. logger.info("Model trained with %d samples", len(X_train))
  93. def predict_market_direction():
  94. global model_trained
  95. if len(order_book_snapshots) == 0 or len(trade_data) == 0:
  96. logger.info("Not enough data to make a prediction")
  97. return
  98. if not model_trained:
  99. logger.info("Model is not trained yet")
  100. return
  101. features = extract_features(order_book_snapshots[-1], trade_data[-1])
  102. if features is not None:
  103. feature_vector = np.array([list(features.values())])
  104. prediction = model.predict(feature_vector)
  105. logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")