|
|
@@ -1,13 +1,15 @@
|
|
|
import json
|
|
|
import time
|
|
|
-
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import queue
|
|
|
import threading
|
|
|
-from logger_config import logger
|
|
|
from collections import deque
|
|
|
-from sklearn.linear_model import LogisticRegression
|
|
|
+from sklearn.model_selection import train_test_split
|
|
|
+from sklearn.ensemble import RandomForestClassifier
|
|
|
+from sklearn.metrics import classification_report
|
|
|
+from sklearn.utils import resample
|
|
|
+from logger_config import logger
|
|
|
|
|
|
# 假设我们有一个数据流,订单簿和成交数据
|
|
|
order_book_snapshots = deque(maxlen=100) # 存储过去100ms的订单簿快照
|
|
|
@@ -15,7 +17,7 @@ trade_data = deque(maxlen=100) # 存储过去100ms的成交数据
|
|
|
|
|
|
# 数据积累的阈值
|
|
|
DATA_THRESHOLD = 20
|
|
|
-model = LogisticRegression() # 初始化模型
|
|
|
+model = RandomForestClassifier(n_estimators=100) # 使用随机森林模型
|
|
|
model_trained = False # 标记模型是否已经被训练
|
|
|
messages = queue.Queue() # 创建一个线程安全队列
|
|
|
|
|
|
@@ -65,12 +67,22 @@ def extract_features(order_book, trade):
|
|
|
trade_volume = trade['qty']
|
|
|
trade_side = 1 if trade['side'] == 'buy' else -1
|
|
|
|
|
|
+ # Count the entries on each side of the order book (len of the bids / asks collections)
|
|
|
+ bid_count = len(order_book['bids'])
|
|
|
+ ask_count = len(order_book['asks'])
|
|
|
+
|
|
|
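+ # Time feature: numeric value of the trade's timestamp via .timestamp() (presumably epoch seconds; confirm the type of trade['timestamp'])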
|
|
|
+ timestamp = trade['timestamp'].timestamp()
|
|
|
+
|
|
|
features = {
|
|
|
'spread': spread,
|
|
|
'bid_depth': bid_depth,
|
|
|
'ask_depth': ask_depth,
|
|
|
'trade_volume': trade_volume,
|
|
|
- 'trade_side': trade_side
|
|
|
+ 'trade_side': trade_side,
|
|
|
+ 'bid_count': bid_count,
|
|
|
+ 'ask_count': ask_count,
|
|
|
+ 'timestamp': timestamp
|
|
|
}
|
|
|
|
|
|
return features
|
|
|
@@ -109,14 +121,45 @@ def generate_label(current_price, future_price):
|
|
|
return 1 if future_price > current_price else 0
|
|
|
|
|
|
|
|
|
+def balance_data(X, y):
|
|
|
+ # Build a DataFrame so rows can be filtered and resampled by label; X must have exactly these 8 columns, in this order
|
|
|
+ df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count', 'ask_count', 'timestamp'])
|
|
|
+ df['label'] = y
|
|
|
+
|
|
|
+ # Split rows by label: 0 goes to df_majority, 1 to df_minority (the names do not reflect the actual class sizes)
|
|
|
+ df_majority = df[df.label == 0]
|
|
|
+ df_minority = df[df.label == 1]
|
|
|
+
|
|
|
+ if len(df_minority) == 0 or len(df_majority) == 0:
|
|
|
+ return X, y # one class is absent, so there is nothing to balance; return the inputs unchanged
|
|
|
+
|
|
|
+ # Equalize class counts: if label 0 outnumbers label 1, sample label 0 down without replacement; otherwise draw len(label-0) rows from label 1 with replacement. NOTE(review): that second case shrinks label 1 and may duplicate rows, despite the "upsampled" name
|
|
|
+ if len(df_majority) > len(df_minority):
|
|
|
+ df_majority_downsampled = resample(df_majority, replace=False, n_samples=len(df_minority), random_state=42)
|
|
|
+ df_balanced = pd.concat([df_majority_downsampled, df_minority])
|
|
|
+ else:
|
|
|
+ df_minority_upsampled = resample(df_minority, replace=True, n_samples=len(df_majority), random_state=42)
|
|
|
+ df_balanced = pd.concat([df_majority, df_minority_upsampled])
|
|
|
+
|
|
|
+ # Separate the balanced frame back into a feature array and a label array
|
|
|
+ X_balanced = df_balanced.drop('label', axis=1).values
|
|
|
+ y_balanced = df_balanced['label'].values
|
|
|
+
|
|
|
+ return X_balanced, y_balanced
|
|
|
+
|
|
|
+
|
|
|
def check_and_train_model():
|
|
|
global model_trained
|
|
|
if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
|
|
|
X_train, y_train = prepare_training_data()
|
|
|
if len(X_train) > 0 and len(y_train) > 0:
|
|
|
+ X_train, y_train = balance_data(X_train, y_train) # 平衡数据
|
|
|
+ X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
|
|
|
model.fit(X_train, y_train)
|
|
|
model_trained = True # 标记模型已经被训练
|
|
|
- # logger.info("Model trained with %d samples", len(X_train))
|
|
|
+ logger.info("Model trained with %d samples", len(X_train))
|
|
|
+ y_pred = model.predict(X_test)
|
|
|
+ logger.info("\n" + classification_report(y_test, y_pred, zero_division=0))
|
|
|
else:
|
|
|
logger.info("Insufficient data to train the model")
|
|
|
else:
|
|
|
@@ -126,11 +169,9 @@ def check_and_train_model():
|
|
|
def predict_market_direction():
|
|
|
global model_trained
|
|
|
if len(order_book_snapshots) == 0 or len(trade_data) == 0:
|
|
|
- # logger.info("Not enough data to make a prediction")
|
|
|
return
|
|
|
|
|
|
if not model_trained:
|
|
|
- # logger.info("Model is not trained yet")
|
|
|
return
|
|
|
|
|
|
features = extract_features(order_book_snapshots[-1], trade_data[-1])
|
|
|
@@ -141,7 +182,6 @@ def predict_market_direction():
|
|
|
show_message(prediction[0])
|
|
|
|
|
|
|
|
|
-# 将消息推送到外部,看看图长什么样
|
|
|
def show_message(market_direction):
|
|
|
global order_book_snapshots, trade_data
|
|
|
|