data_processing.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. import json
  2. import time
  3. import numpy as np
  4. import pandas as pd
  5. import queue
  6. import threading
  7. from collections import deque
  8. from sklearn.model_selection import train_test_split
  9. from sklearn.ensemble import RandomForestClassifier
  10. from sklearn.metrics import classification_report
  11. from sklearn.utils import resample
  12. from logger_config import logger
  13. # 假设我们有一个数据流,订单簿和成交数据
  14. order_book_snapshots = deque(maxlen=100) # 存储过去100ms的订单簿快照
  15. trade_data = deque(maxlen=100) # 存储过去100ms的成交数据
  16. prediction_window = deque(maxlen=10) # 用于平滑预测结果的滑动窗口
  17. # 数据积累的阈值
  18. DATA_THRESHOLD = 20
  19. model = RandomForestClassifier(n_estimators=100) # 使用随机森林模型
  20. model_trained = False # 标记模型是否已经被训练
  21. messages = queue.Queue() # 创建一个线程安全队列
  22. stop_event = threading.Event()
  23. def on_message_trade(_ws, message):
  24. global trade_data
  25. json_message = json.loads(message)
  26. trade = {
  27. 'price': float(json_message['data']['p']),
  28. 'qty': float(json_message['data']['q']),
  29. 'timestamp': pd.to_datetime(json_message['data']['T'], unit='ms'),
  30. 'side': 'sell' if json_message['data']['m'] else 'buy'
  31. }
  32. trade_data.append(trade)
  33. check_and_train_model()
  34. predict_market_direction()
  35. def on_message_depth(_ws, message):
  36. global order_book_snapshots
  37. json_message = json.loads(message)
  38. bids = json_message['data']['b'][:10] # Top 10 bids
  39. asks = json_message['data']['a'][:10] # Top 10 asks
  40. timestamp = pd.to_datetime(json_message['data']['E'], unit='ms')
  41. order_book_snapshots.append({
  42. 'bids': bids,
  43. 'asks': asks,
  44. 'timestamp': timestamp
  45. })
  46. check_and_train_model()
  47. predict_market_direction()
  48. def extract_features(order_book, trade):
  49. # 计算买卖盘差距(spread)
  50. best_bid = float(order_book['bids'][0][0])
  51. best_ask = float(order_book['asks'][0][0])
  52. spread = best_ask - best_bid
  53. # 计算买卖盘深度
  54. bid_depth = sum(float(bid[1]) for bid in order_book['bids'])
  55. ask_depth = sum(float(ask[1]) for ask in order_book['asks'])
  56. # 计算成交量和方向
  57. trade_volume = trade['qty']
  58. trade_side = 1 if trade['side'] == 'buy' else -1
  59. # 计算买卖盘数量
  60. bid_count = len(order_book['bids'])
  61. ask_count = len(order_book['asks'])
  62. # 计算时间特征
  63. timestamp = trade['timestamp'].timestamp()
  64. features = {
  65. 'spread': spread,
  66. 'bid_depth': bid_depth,
  67. 'ask_depth': ask_depth,
  68. 'trade_volume': trade_volume,
  69. 'trade_side': trade_side,
  70. 'bid_count': bid_count,
  71. 'ask_count': ask_count,
  72. 'timestamp': timestamp
  73. }
  74. return features
  75. def prepare_training_data():
  76. # 提取特征和标签
  77. X = []
  78. y = []
  79. for i in range(len(order_book_snapshots) - 1):
  80. if i + 1 >= len(order_book_snapshots) or i >= len(trade_data):
  81. break
  82. current_order_book = order_book_snapshots[i]
  83. current_trade = trade_data[i]
  84. future_order_book = order_book_snapshots[i + 1]
  85. # 提取当前的特征
  86. features = extract_features(current_order_book, current_trade)
  87. X.append(list(features.values()))
  88. # 生成标签
  89. current_price = float(current_order_book['bids'][0][0])
  90. future_price = float(future_order_book['bids'][0][0])
  91. label = generate_label(current_price, future_price)
  92. y.append(label)
  93. # 将特征和标签转换为NumPy数组
  94. X_train = np.array(X)
  95. y_train = np.array(y)
  96. return X_train, y_train
  97. def generate_label(current_price, future_price):
  98. return 1 if future_price > current_price else 0
  99. def balance_data(X, y):
  100. # 将数据转换为 DataFrame,便于处理
  101. df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count',
  102. 'ask_count', 'timestamp'])
  103. df['label'] = y
  104. # 分别获取两类数据
  105. df_majority = df[df.label == 0]
  106. df_minority = df[df.label == 1]
  107. if len(df_minority) == 0 or len(df_majority) == 0:
  108. return X, y # 如果某一类数据为空,返回原数据
  109. # 使用下采样或上采样来平衡数据
  110. if len(df_majority) > len(df_minority):
  111. df_majority_downsampled = resample(df_majority, replace=False, n_samples=len(df_minority), random_state=42)
  112. df_balanced = pd.concat([df_majority_downsampled, df_minority])
  113. else:
  114. df_minority_upsampled = resample(df_minority, replace=True, n_samples=len(df_majority), random_state=42)
  115. df_balanced = pd.concat([df_majority, df_minority_upsampled])
  116. # 提取平衡后的特征和标签
  117. X_balanced = df_balanced.drop('label', axis=1).values
  118. y_balanced = df_balanced['label'].values
  119. return X_balanced, y_balanced
  120. def check_and_train_model():
  121. global model_trained
  122. if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
  123. X_train, y_train = prepare_training_data()
  124. if len(X_train) > 0 and len(y_train) > 0:
  125. X_train, y_train = balance_data(X_train, y_train) # 平衡数据
  126. X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
  127. model.fit(X_train, y_train)
  128. model_trained = True # 标记模型已经被训练
  129. logger.info("Model trained with %d samples", len(X_train))
  130. y_pred = model.predict(X_test)
  131. logger.info("\n" + classification_report(y_test, y_pred, zero_division=0))
  132. else:
  133. logger.info("Insufficient data to train the model")
  134. else:
  135. logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots),
  136. len(trade_data))
  137. def predict_market_direction():
  138. global model_trained, prediction_window
  139. if len(order_book_snapshots) == 0 or len(trade_data) == 0:
  140. return
  141. if not model_trained:
  142. return
  143. features = extract_features(order_book_snapshots[-1], trade_data[-1])
  144. if features is not None:
  145. feature_vector = np.array([list(features.values())])
  146. prediction = model.predict(feature_vector)[0]
  147. # 将预测结果添加到滑动窗口中
  148. prediction_window.append(prediction)
  149. # 根据滑动窗口中的多数结果确定当前的预测方向
  150. if len(prediction_window) == prediction_window.maxlen:
  151. smoothed_prediction = 1 if sum(prediction_window) > len(prediction_window) / 2 else 0
  152. logger.info("Predicted Market Direction (smoothed): %s", "Up" if smoothed_prediction == 1 else "Down")
  153. show_message(smoothed_prediction)
  154. def show_message(market_direction):
  155. global order_book_snapshots, trade_data
  156. if len(order_book_snapshots) > 0 and len(trade_data) > 0:
  157. # 获取最新的订单簿数据和成交数据
  158. latest_order_book = order_book_snapshots[-1]
  159. latest_trade = trade_data[-1]
  160. # 提取asks和bids数据
  161. asks = [[float(price), 0] for price, qty in latest_order_book['asks']]
  162. bids = [[float(price), 0] for price, qty in latest_order_book['bids']]
  163. # 排序asks和bids数据
  164. asks_sorted = sorted(asks, key=lambda x: x[0])
  165. asks_sorted[-1] = [asks_sorted[-1][0], 1 if market_direction == 1 else 0]
  166. bids_sorted = sorted(bids, key=lambda x: x[0], reverse=True)
  167. bids_sorted[-1] = [bids_sorted[-1][0], 0 if market_direction == 1 else 1]
  168. last_price = latest_trade['price']
  169. # last_qty = latest_trade['qty']
  170. last_qty = 0
  171. side = latest_trade['side']
  172. # 生成数据字典
  173. data = {
  174. "asks": asks_sorted,
  175. "bids": bids_sorted,
  176. "last_price": last_price,
  177. "last_qty": last_qty,
  178. "side": side,
  179. "time": int(time.time() * 1000)
  180. }
  181. # 将数据放入消息队列
  182. messages.put(data)