data_processing.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import json
  2. import time
  3. import numpy as np
  4. import pandas as pd
  5. import queue
  6. import threading
  7. from collections import deque
  8. from sklearn.model_selection import train_test_split
  9. from sklearn.ensemble import RandomForestClassifier
  10. from sklearn.metrics import classification_report
  11. from sklearn.utils import resample
  12. from logger_config import logger
  13. # 假设我们有一个数据流,订单簿和成交数据
  14. order_book_snapshots = deque(maxlen=100) # 存储过去100ms的订单簿快照
  15. trade_data = deque(maxlen=100) # 存储过去100ms的成交数据
  16. # 数据积累的阈值
  17. DATA_THRESHOLD = 20
  18. model = RandomForestClassifier(n_estimators=100) # 使用随机森林模型
  19. model_trained = False # 标记模型是否已经被训练
  20. messages = queue.Queue() # 创建一个线程安全队列
  21. stop_event = threading.Event()
  22. def on_message_trade(_ws, message):
  23. global trade_data
  24. json_message = json.loads(message)
  25. trade = {
  26. 'price': float(json_message['data']['p']),
  27. 'qty': float(json_message['data']['q']),
  28. 'timestamp': pd.to_datetime(json_message['data']['T'], unit='ms'),
  29. 'side': 'sell' if json_message['data']['m'] else 'buy'
  30. }
  31. trade_data.append(trade)
  32. check_and_train_model()
  33. predict_market_direction()
  34. def on_message_depth(_ws, message):
  35. global order_book_snapshots
  36. json_message = json.loads(message)
  37. bids = json_message['data']['b'][:10] # Top 10 bids
  38. asks = json_message['data']['a'][:10] # Top 10 asks
  39. timestamp = pd.to_datetime(json_message['data']['E'], unit='ms')
  40. order_book_snapshots.append({
  41. 'bids': bids,
  42. 'asks': asks,
  43. 'timestamp': timestamp
  44. })
  45. check_and_train_model()
  46. predict_market_direction()
  47. def extract_features(order_book, trade):
  48. # 计算买卖盘差距(spread)
  49. best_bid = float(order_book['bids'][0][0])
  50. best_ask = float(order_book['asks'][0][0])
  51. spread = best_ask - best_bid
  52. # 计算买卖盘深度
  53. bid_depth = sum(float(bid[1]) for bid in order_book['bids'])
  54. ask_depth = sum(float(ask[1]) for ask in order_book['asks'])
  55. # 计算成交量和方向
  56. trade_volume = trade['qty']
  57. trade_side = 1 if trade['side'] == 'buy' else -1
  58. # 计算买卖盘数量
  59. bid_count = len(order_book['bids'])
  60. ask_count = len(order_book['asks'])
  61. # 计算时间特征
  62. timestamp = trade['timestamp'].timestamp()
  63. features = {
  64. 'spread': spread,
  65. 'bid_depth': bid_depth,
  66. 'ask_depth': ask_depth,
  67. 'trade_volume': trade_volume,
  68. 'trade_side': trade_side,
  69. 'bid_count': bid_count,
  70. 'ask_count': ask_count,
  71. 'timestamp': timestamp
  72. }
  73. return features
  74. def prepare_training_data():
  75. # 提取特征和标签
  76. X = []
  77. y = []
  78. for i in range(len(order_book_snapshots) - 1):
  79. if i + 1 >= len(order_book_snapshots) or i >= len(trade_data):
  80. break
  81. current_order_book = order_book_snapshots[i]
  82. current_trade = trade_data[i]
  83. future_order_book = order_book_snapshots[i + 1]
  84. # 提取当前的特征
  85. features = extract_features(current_order_book, current_trade)
  86. X.append(list(features.values()))
  87. # 生成标签
  88. current_price = float(current_order_book['bids'][0][0])
  89. future_price = float(future_order_book['bids'][0][0])
  90. label = generate_label(current_price, future_price)
  91. y.append(label)
  92. # 将特征和标签转换为NumPy数组
  93. X_train = np.array(X)
  94. y_train = np.array(y)
  95. return X_train, y_train
  96. def generate_label(current_price, future_price):
  97. return 1 if future_price > current_price else 0
  98. def balance_data(X, y):
  99. # 将数据转换为 DataFrame,便于处理
  100. df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count', 'ask_count', 'timestamp'])
  101. df['label'] = y
  102. # 分别获取两类数据
  103. df_majority = df[df.label == 0]
  104. df_minority = df[df.label == 1]
  105. if len(df_minority) == 0 or len(df_majority) == 0:
  106. return X, y # 如果某一类数据为空,返回原数据
  107. # 使用下采样或上采样来平衡数据
  108. if len(df_majority) > len(df_minority):
  109. df_majority_downsampled = resample(df_majority, replace=False, n_samples=len(df_minority), random_state=42)
  110. df_balanced = pd.concat([df_majority_downsampled, df_minority])
  111. else:
  112. df_minority_upsampled = resample(df_minority, replace=True, n_samples=len(df_majority), random_state=42)
  113. df_balanced = pd.concat([df_majority, df_minority_upsampled])
  114. # 提取平衡后的特征和标签
  115. X_balanced = df_balanced.drop('label', axis=1).values
  116. y_balanced = df_balanced['label'].values
  117. return X_balanced, y_balanced
  118. def check_and_train_model():
  119. global model_trained
  120. if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
  121. X_train, y_train = prepare_training_data()
  122. if len(X_train) > 0 and len(y_train) > 0:
  123. X_train, y_train = balance_data(X_train, y_train) # 平衡数据
  124. X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
  125. model.fit(X_train, y_train)
  126. model_trained = True # 标记模型已经被训练
  127. logger.info("Model trained with %d samples", len(X_train))
  128. y_pred = model.predict(X_test)
  129. logger.info("\n" + classification_report(y_test, y_pred, zero_division=0))
  130. else:
  131. logger.info("Insufficient data to train the model")
  132. else:
  133. logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots), len(trade_data))
  134. def predict_market_direction():
  135. global model_trained
  136. if len(order_book_snapshots) == 0 or len(trade_data) == 0:
  137. return
  138. if not model_trained:
  139. return
  140. features = extract_features(order_book_snapshots[-1], trade_data[-1])
  141. if features is not None:
  142. feature_vector = np.array([list(features.values())])
  143. prediction = model.predict(feature_vector)
  144. logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
  145. show_message(prediction[0])
  146. def show_message(market_direction):
  147. global order_book_snapshots, trade_data
  148. if len(order_book_snapshots) > 0 and len(trade_data) > 0:
  149. # 获取最新的订单簿数据和成交数据
  150. latest_order_book = order_book_snapshots[-1]
  151. latest_trade = trade_data[-1]
  152. # 提取asks和bids数据
  153. asks = [[float(price), 1 if market_direction == 1 else 0] for price, qty in latest_order_book['asks']]
  154. bids = [[float(price), 1 if market_direction == 0 else 0] for price, qty in latest_order_book['bids']]
  155. # 排序asks和bids数据
  156. asks_sorted = sorted(asks, key=lambda x: x[0])
  157. bids_sorted = sorted(bids, key=lambda x: x[0], reverse=True)
  158. last_price = latest_trade['price']
  159. # last_qty = latest_trade['qty']
  160. last_qty = 0
  161. side = latest_trade['side']
  162. # 生成数据字典
  163. data = {
  164. "asks": asks_sorted,
  165. "bids": bids_sorted,
  166. "last_price": last_price,
  167. "last_qty": last_qty,
  168. "side": side,
  169. "time": int(time.time() * 1000)
  170. }
  171. # 将数据放入消息队列
  172. messages.put(data)