data_processing.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import json
  2. import time
  3. import numpy as np
  4. import pandas as pd
  5. import queue
  6. import threading
  7. from logger_config import logger
  8. from collections import deque
  9. from sklearn.linear_model import LogisticRegression
  10. # 假设我们有一个数据流,订单簿和成交数据
  11. order_book_snapshots = deque(maxlen=10000) # 存储过去100ms的订单簿快照
  12. trade_data = deque(maxlen=10000) # 存储过去100ms的成交数据
  13. # 数据积累的阈值
  14. DATA_THRESHOLD = 20
  15. model = LogisticRegression() # 初始化模型
  16. model_trained = False # 标记模型是否已经被训练
  17. messages = queue.Queue() # 创建一个线程安全队列
  18. stop_event = threading.Event()
  19. def on_message_trade(_ws, message):
  20. global trade_data
  21. json_message = json.loads(message)
  22. trade = {
  23. 'price': float(json_message['data']['p']),
  24. 'qty': float(json_message['data']['q']),
  25. 'timestamp': pd.to_datetime(json_message['data']['T'], unit='ms'),
  26. 'side': 'sell' if json_message['data']['m'] else 'buy'
  27. }
  28. trade_data.append(trade)
  29. check_and_train_model()
  30. predict_market_direction()
  31. def on_message_depth(_ws, message):
  32. global order_book_snapshots
  33. json_message = json.loads(message)
  34. bids = json_message['data']['b'][:10] # Top 10 bids
  35. asks = json_message['data']['a'][:10] # Top 10 asks
  36. timestamp = pd.to_datetime(json_message['data']['E'], unit='ms')
  37. order_book_snapshots.append({
  38. 'bids': bids,
  39. 'asks': asks,
  40. 'timestamp': timestamp
  41. })
  42. check_and_train_model()
  43. predict_market_direction()
  44. def extract_features(order_book, trade):
  45. # 计算买卖盘差距(spread)
  46. best_bid = float(order_book['bids'][0][0])
  47. best_ask = float(order_book['asks'][0][0])
  48. spread = best_ask - best_bid
  49. # 计算买卖盘深度
  50. bid_depth = sum(float(bid[1]) for bid in order_book['bids'])
  51. ask_depth = sum(float(ask[1]) for ask in order_book['asks'])
  52. # 计算成交量和方向
  53. trade_volume = trade['qty']
  54. trade_side = 1 if trade['side'] == 'buy' else -1
  55. features = {
  56. 'spread': spread,
  57. 'bid_depth': bid_depth,
  58. 'ask_depth': ask_depth,
  59. 'trade_volume': trade_volume,
  60. 'trade_side': trade_side
  61. }
  62. return features
  63. def prepare_training_data():
  64. # 提取特征和标签
  65. X = []
  66. y = []
  67. for i in range(len(order_book_snapshots) - 1):
  68. if i + 1 >= len(order_book_snapshots) or i >= len(trade_data):
  69. break
  70. current_order_book = order_book_snapshots[i]
  71. current_trade = trade_data[i]
  72. future_order_book = order_book_snapshots[i + 1]
  73. # 提取当前的特征
  74. features = extract_features(current_order_book, current_trade)
  75. X.append(list(features.values()))
  76. # 生成标签
  77. current_price = float(current_order_book['bids'][0][0])
  78. future_price = float(future_order_book['bids'][0][0])
  79. label = generate_label(current_price, future_price)
  80. y.append(label)
  81. # 将特征和标签转换为NumPy数组
  82. X_train = np.array(X)
  83. y_train = np.array(y)
  84. return X_train, y_train
  85. def generate_label(current_price, future_price):
  86. return 1 if future_price > current_price else 0
  87. def check_and_train_model():
  88. global model_trained
  89. if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
  90. X_train, y_train = prepare_training_data()
  91. if len(X_train) > 0 and len(y_train) > 0:
  92. model.fit(X_train, y_train)
  93. model_trained = True # 标记模型已经被训练
  94. # logger.info("Model trained with %d samples", len(X_train))
  95. else:
  96. logger.info("Insufficient data to train the model")
  97. else:
  98. logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots), len(trade_data))
  99. def predict_market_direction():
  100. global model_trained
  101. if len(order_book_snapshots) == 0 or len(trade_data) == 0:
  102. # logger.info("Not enough data to make a prediction")
  103. return
  104. if not model_trained:
  105. # logger.info("Model is not trained yet")
  106. return
  107. features = extract_features(order_book_snapshots[-1], trade_data[-1])
  108. if features is not None:
  109. feature_vector = np.array([list(features.values())])
  110. prediction = model.predict(feature_vector)
  111. logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
  112. show_message(prediction[0])
  113. # 将消息推送到外部,看看图长什么样
  114. def show_message(market_direction):
  115. global order_book_snapshots, trade_data
  116. if len(order_book_snapshots) > 0 and len(trade_data) > 0:
  117. # 获取最新的订单簿数据和成交数据
  118. latest_order_book = order_book_snapshots[-1]
  119. latest_trade = trade_data[-1]
  120. # 提取asks和bids数据
  121. asks = [[float(price), 1 if market_direction == 1 else 0] for price, qty in latest_order_book['asks']]
  122. bids = [[float(price), 1 if market_direction == 0 else 0] for price, qty in latest_order_book['bids']]
  123. # 排序asks和bids数据
  124. asks_sorted = sorted(asks, key=lambda x: x[0])
  125. bids_sorted = sorted(bids, key=lambda x: x[0], reverse=True)
  126. last_price = latest_trade['price']
  127. last_qty = latest_trade['qty']
  128. side = latest_trade['side']
  129. # 生成数据字典
  130. data = {
  131. "asks": asks_sorted,
  132. "bids": bids_sorted,
  133. "last_price": last_price,
  134. "last_qty": last_qty,
  135. "side": side,
  136. "time": int(time.time() * 1000)
  137. }
  138. # 将数据放入消息队列
  139. messages.put(data)