Jelajahi Sumber

能预测方向的版本

skyffire 1 tahun lalu
induk
melakukan
f1e74bf2ab
1 file diubah dengan 14 tambahan dan 9 penghapusan
  1. 14 9
      binance_order_flow/data_processing.py

+ 14 - 9
binance_order_flow/data_processing.py

@@ -1,10 +1,8 @@
 import json
 import numpy as np
 import pandas as pd
-import time
 import queue
 import threading
-from datetime import datetime
 from logger_config import logger
 from collections import deque
 from sklearn.linear_model import LogisticRegression
@@ -14,7 +12,7 @@ order_book_snapshots = deque(maxlen=10000)  # 存储过去100ms的订单簿快照
 trade_data = deque(maxlen=10000)  # 存储过去100ms的成交数据
 
 # 数据积累的阈值
-DATA_THRESHOLD = 10
+DATA_THRESHOLD = 20
 model = LogisticRegression()  # 初始化模型
 model_trained = False  # 标记模型是否已经被训练
 messages = queue.Queue()  # 创建一个线程安全队列
@@ -32,8 +30,8 @@ def on_message_trade(_ws, message):
         'side': 'sell' if json_message['data']['m'] else 'buy'
     }
     trade_data.append(trade)
-    predict_market_direction()
     check_and_train_model()
+    predict_market_direction()
 
 
 def on_message_depth(_ws, message):
@@ -47,8 +45,8 @@ def on_message_depth(_ws, message):
         'asks': asks,
         'timestamp': timestamp
     })
-    predict_market_direction()
     check_and_train_model()
+    predict_market_direction()
 
 
 def extract_features(order_book, trade):
@@ -82,6 +80,8 @@ def prepare_training_data():
     y = []
 
     for i in range(len(order_book_snapshots) - 1):
+        if i + 1 >= len(order_book_snapshots) or i >= len(trade_data):
+            break
         current_order_book = order_book_snapshots[i]
         current_trade = trade_data[i]
         future_order_book = order_book_snapshots[i + 1]
@@ -111,9 +111,14 @@ def check_and_train_model():
     global model_trained
     if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
         X_train, y_train = prepare_training_data()
-        model.fit(X_train, y_train)
-        model_trained = True  # 标记模型已经被训练
-        logger.info("Model trained with %d samples", len(X_train))
+        if len(X_train) > 0 and len(y_train) > 0:
+            model.fit(X_train, y_train)
+            model_trained = True  # 标记模型已经被训练
+            logger.info("Model trained with %d samples", len(X_train))
+        else:
+            logger.info("Insufficient data to train the model")
+    else:
+        logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots), len(trade_data))
 
 
 def predict_market_direction():
@@ -130,4 +135,4 @@ def predict_market_direction():
     if features is not None:
         feature_vector = np.array([list(features.values())])
         prediction = model.predict(feature_vector)
-        logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
+        logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")