skyffire 1 жил өмнө
parent
commit
e34f01ed10

+ 49 - 9
binance_order_flow/data_processing.py

@@ -1,13 +1,15 @@
 import json
 import time
-
 import numpy as np
 import pandas as pd
 import queue
 import threading
-from logger_config import logger
 from collections import deque
-from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.metrics import classification_report
+from sklearn.utils import resample
+from logger_config import logger
 
 # Assume a data stream supplies order-book snapshots and trade records.
 order_book_snapshots = deque(maxlen=100)  # keeps the 100 most recent order-book snapshots (the original note said "past 100 ms", but maxlen bounds the item count, not time)
@@ -15,7 +17,7 @@ trade_data = deque(maxlen=100)  # 存储过去100ms的成交数据
 
 # Minimum number of buffered snapshots and trades required before training is attempted
 DATA_THRESHOLD = 20
-model = LogisticRegression()  # initialise the model
+model = RandomForestClassifier(n_estimators=100)  # use a random-forest model
 model_trained = False  # flag: whether the model has been trained yet
 messages = queue.Queue()  # create a thread-safe queue
 
@@ -65,12 +67,22 @@ def extract_features(order_book, trade):
     trade_volume = trade['qty']
     trade_side = 1 if trade['side'] == 'buy' else -1
 
+    # 计算买卖盘数量
+    bid_count = len(order_book['bids'])
+    ask_count = len(order_book['asks'])
+
+    # 计算时间特征
+    timestamp = trade['timestamp'].timestamp()
+
     features = {
         'spread': spread,
         'bid_depth': bid_depth,
         'ask_depth': ask_depth,
         'trade_volume': trade_volume,
-        'trade_side': trade_side
+        'trade_side': trade_side,
+        'bid_count': bid_count,
+        'ask_count': ask_count,
+        'timestamp': timestamp
     }
 
     return features
@@ -109,14 +121,45 @@ def generate_label(current_price, future_price):
     return 1 if future_price > current_price else 0
 
 
+def balance_data(X, y):
+    """Balance the two label classes by downsampling the larger one.
+
+    Args:
+        X: Feature rows whose columns match the eight names in ``columns``
+            below.
+        y: Labels aligned with ``X``. Only rows labelled 0 or 1 are kept.
+
+    Returns:
+        ``(X, y)`` unchanged when either class has no rows; otherwise a
+        pair of NumPy arrays ``(X_balanced, y_balanced)`` holding the same
+        number of rows for label 0 and label 1.
+    """
+    columns = ['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count', 'ask_count', 'timestamp']
+
+    # Work on a DataFrame so rows and labels stay aligned while sampling.
+    df = pd.DataFrame(X, columns=columns)
+    df['label'] = y
+
+    df_zero = df[df.label == 0]
+    df_one = df[df.label == 1]
+
+    if len(df_zero) == 0 or len(df_one) == 0:
+        return X, y  # one class is empty: nothing to balance against
+
+    # Shrink whichever class is actually larger, sampling WITHOUT replacement
+    # so that no row is duplicated. The previous version always resampled the
+    # label-1 rows with replacement when they were at least as numerous as the
+    # label-0 rows, which duplicated some rows, dropped others, and replaced
+    # already-balanced data with a bootstrap sample.
+    target = min(len(df_zero), len(df_one))
+    if len(df_zero) > target:
+        df_zero = resample(df_zero, replace=False, n_samples=target, random_state=42)
+    elif len(df_one) > target:
+        df_one = resample(df_one, replace=False, n_samples=target, random_state=42)
+    df_balanced = pd.concat([df_zero, df_one])
+
+    # Split the balanced frame back into features and labels.
+    X_balanced = df_balanced.drop('label', axis=1).values
+    y_balanced = df_balanced['label'].values
+
+    return X_balanced, y_balanced
+
+
 def check_and_train_model():
     global model_trained
     if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
         X_train, y_train = prepare_training_data()
         if len(X_train) > 0 and len(y_train) > 0:
+            X_train, y_train = balance_data(X_train, y_train)  # 平衡数据
+            X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
             model.fit(X_train, y_train)
             model_trained = True  # 标记模型已经被训练
-            # logger.info("Model trained with %d samples", len(X_train))
+            logger.info("Model trained with %d samples", len(X_train))
+            y_pred = model.predict(X_test)
+            logger.info("\n" + classification_report(y_test, y_pred, zero_division=0))
         else:
             logger.info("Insufficient data to train the model")
     else:
@@ -126,11 +169,9 @@ def check_and_train_model():
 def predict_market_direction():
     global model_trained
     if len(order_book_snapshots) == 0 or len(trade_data) == 0:
-        # logger.info("Not enough data to make a prediction")
         return
 
     if not model_trained:
-        # logger.info("Model is not trained yet")
         return
 
     features = extract_features(order_book_snapshots[-1], trade_data[-1])
@@ -141,7 +182,6 @@ def predict_market_direction():
         show_message(prediction[0])
 
 
-# 将消息推送到外部,看看图长什么样
 def show_message(market_direction):
     global order_book_snapshots, trade_data