Przeglądaj źródła

可以勉强预测方向的版本。

skyffire 1 rok temu
rodzic
commit
70fda57b8a
1 zmienionych plików z 14 dodań i 31 usunięć
  1. 14 31
      binance_order_flow/data_processing.py

+ 14 - 31
binance_order_flow/data_processing.py

@@ -1,5 +1,4 @@
 import json
-
 import numpy as np
 import pandas as pd
 import time
@@ -15,12 +14,14 @@ order_book_snapshots = deque(maxlen=10000)  # 存储过去100ms的订单簿快
 trade_data = deque(maxlen=10000)  # 存储过去100ms的成交数据
 
 # 数据积累的阈值
-DATA_THRESHOLD = 100
+DATA_THRESHOLD = 10
 model = LogisticRegression()  # 初始化模型
+model_trained = False  # 标记模型是否已经被训练
 messages = queue.Queue()  # 创建一个线程安全队列
 
 stop_event = threading.Event()
 
+
 def on_message_trade(_ws, message):
     global trade_data
     json_message = json.loads(message)
@@ -32,6 +33,7 @@ def on_message_trade(_ws, message):
     }
     trade_data.append(trade)
     predict_market_direction()
+    check_and_train_model()
 
 
 def on_message_depth(_ws, message):
@@ -46,6 +48,7 @@ def on_message_depth(_ws, message):
         'timestamp': timestamp
     })
     predict_market_direction()
+    check_and_train_model()
 
 
 def extract_features(order_book, trade):
@@ -105,46 +108,26 @@ def generate_label(current_price, future_price):
 
 
 def check_and_train_model():
+    global model_trained
     if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
         X_train, y_train = prepare_training_data()
         model.fit(X_train, y_train)
-        logger.info("Model trained with", len(X_train), "samples")
+        model_trained = True  # 标记模型已经被训练
+        logger.info("Model trained with %d samples", len(X_train))
 
 
 def predict_market_direction():
+    global model_trained
     if len(order_book_snapshots) == 0 or len(trade_data) == 0:
         logger.info("Not enough data to make a prediction")
         return
 
+    if not model_trained:
+        logger.info("Model is not trained yet")
+        return
+
     features = extract_features(order_book_snapshots[-1], trade_data[-1])
     if features is not None:
         feature_vector = np.array([list(features.values())])
         prediction = model.predict(feature_vector)
-        logger.info("Predicted Market Direction:", "Up" if prediction[0] == 1 else "Down")
-
-
-# # 计算成交概率
-# def show_message():
-#     global df_order_book, last_trade
-#
-#     if not df_order_book.empty and last_trade['price'] is not None:
-#         last_price = last_trade['price']
-#         asks = [[price, qty] for price, qty in
-#                 zip(df_order_book['ask_price'].iloc[0], df_order_book['ask_qty'].iloc[0])]
-#         bids = [[price, qty] for price, qty in
-#                 zip(df_order_book['bid_price'].iloc[0], df_order_book['bid_qty'].iloc[0])]
-#
-#         asks_sorted = sorted(asks, key=lambda x: x[0])
-#         bids_sorted = sorted(bids, key=lambda x: x[0], reverse=True)
-#
-#         last_qty = last_trade['qty']
-#         side = last_trade['side']
-#         data = {
-#             "asks": asks_sorted,
-#             "bids": bids_sorted,
-#             "last_price": last_price,
-#             "last_qty": last_qty,
-#             "side": side,
-#             "time": int(time.time() * 1000)
-#         }
-#         messages.put(data)
+        logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")