Browse Source

平滑的方向处理。

skyffire 1 year ago
parent
commit
b7bca5b081
1 changed files with 16 additions and 6 deletions
  1. 16 6
      binance_order_flow/data_processing.py

+ 16 - 6
binance_order_flow/data_processing.py

@@ -14,6 +14,7 @@ from logger_config import logger
 # 假设我们有一个数据流,订单簿和成交数据
 order_book_snapshots = deque(maxlen=100)  # 存储过去100ms的订单簿快照
 trade_data = deque(maxlen=100)  # 存储过去100ms的成交数据
+prediction_window = deque(maxlen=10)  # 用于平滑预测结果的滑动窗口
 
 # 数据积累的阈值
 DATA_THRESHOLD = 20
@@ -123,7 +124,8 @@ def generate_label(current_price, future_price):
 
 def balance_data(X, y):
     # 将数据转换为 DataFrame,便于处理
-    df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count', 'ask_count', 'timestamp'])
+    df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count',
+                                  'ask_count', 'timestamp'])
     df['label'] = y
 
     # 分别获取两类数据
@@ -163,11 +165,12 @@ def check_and_train_model():
         else:
             logger.info("Insufficient data to train the model")
     else:
-        logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots), len(trade_data))
+        logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots),
+                    len(trade_data))
 
 
 def predict_market_direction():
-    global model_trained
+    global model_trained, prediction_window
     if len(order_book_snapshots) == 0 or len(trade_data) == 0:
         return
 
@@ -177,9 +180,16 @@ def predict_market_direction():
     features = extract_features(order_book_snapshots[-1], trade_data[-1])
     if features is not None:
         feature_vector = np.array([list(features.values())])
-        prediction = model.predict(feature_vector)
-        logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
-        show_message(prediction[0])
+        prediction = model.predict(feature_vector)[0]
+
+        # 将预测结果添加到滑动窗口中
+        prediction_window.append(prediction)
+
+        # 根据滑动窗口中的多数结果确定当前的预测方向
+        if len(prediction_window) == prediction_window.maxlen:
+            smoothed_prediction = 1 if sum(prediction_window) > len(prediction_window) / 2 else 0
+            logger.info("Predicted Market Direction (smoothed): %s", "Up" if smoothed_prediction == 1 else "Down")
+            show_message(smoothed_prediction)
 
 
 def show_message(market_direction):