|
|
@@ -5,10 +5,6 @@ import pandas as pd
|
|
|
import queue
|
|
|
import threading
|
|
|
from collections import deque
|
|
|
-from sklearn.model_selection import train_test_split
|
|
|
-from sklearn.ensemble import RandomForestClassifier
|
|
|
-from sklearn.metrics import classification_report
|
|
|
-from sklearn.utils import resample
|
|
|
from logger_config import logger
|
|
|
|
|
|
# 假设我们有一个数据流,订单簿和成交数据
|
|
|
@@ -18,8 +14,6 @@ prediction_window = deque(maxlen=10) # 用于平滑预测结果的滑动窗口
|
|
|
|
|
|
# 数据积累的阈值
|
|
|
DATA_THRESHOLD = 20
|
|
|
-model = RandomForestClassifier(n_estimators=100) # 使用随机森林模型
|
|
|
-model_trained = False # 标记模型是否已经被训练
|
|
|
messages = queue.Queue() # 创建一个线程安全队列
|
|
|
|
|
|
stop_event = threading.Event()
|
|
|
@@ -35,7 +29,6 @@ def on_message_trade(_ws, message):
|
|
|
'side': 'sell' if json_message['data']['m'] else 'buy'
|
|
|
}
|
|
|
trade_data.append(trade)
|
|
|
- check_and_train_model()
|
|
|
predict_market_direction()
|
|
|
|
|
|
|
|
|
@@ -50,7 +43,6 @@ def on_message_depth(_ws, message):
|
|
|
'asks': asks,
|
|
|
'timestamp': timestamp
|
|
|
})
|
|
|
- check_and_train_model()
|
|
|
predict_market_direction()
|
|
|
|
|
|
|
|
|
@@ -89,107 +81,29 @@ def extract_features(order_book, trade):
|
|
|
return features
|
|
|
|
|
|
|
|
|
-def prepare_training_data():
|
|
|
- # 提取特征和标签
|
|
|
- X = []
|
|
|
- y = []
|
|
|
-
|
|
|
- for i in range(len(order_book_snapshots) - 1):
|
|
|
- if i + 1 >= len(order_book_snapshots) or i >= len(trade_data):
|
|
|
- break
|
|
|
- current_order_book = order_book_snapshots[i]
|
|
|
- current_trade = trade_data[i]
|
|
|
- future_order_book = order_book_snapshots[i + 1]
|
|
|
-
|
|
|
- # 提取当前的特征
|
|
|
- features = extract_features(current_order_book, current_trade)
|
|
|
- X.append(list(features.values()))
|
|
|
-
|
|
|
- # 生成标签
|
|
|
- current_price = float(current_order_book['bids'][0][0])
|
|
|
- future_price = float(future_order_book['bids'][0][0])
|
|
|
- label = generate_label(current_price, future_price)
|
|
|
- y.append(label)
|
|
|
-
|
|
|
- # 将特征和标签转换为NumPy数组
|
|
|
- X_train = np.array(X)
|
|
|
- y_train = np.array(y)
|
|
|
-
|
|
|
- return X_train, y_train
|
|
|
-
|
|
|
-
|
|
|
def generate_label(current_price, future_price):
|
|
|
return 1 if future_price > current_price else 0
|
|
|
|
|
|
|
|
|
-def balance_data(X, y):
|
|
|
- # 将数据转换为 DataFrame,便于处理
|
|
|
- df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count',
|
|
|
- 'ask_count', 'timestamp'])
|
|
|
- df['label'] = y
|
|
|
-
|
|
|
- # 分别获取两类数据
|
|
|
- df_majority = df[df.label == 0]
|
|
|
- df_minority = df[df.label == 1]
|
|
|
-
|
|
|
- if len(df_minority) == 0 or len(df_majority) == 0:
|
|
|
- return X, y # 如果某一类数据为空,返回原数据
|
|
|
-
|
|
|
- # 使用下采样或上采样来平衡数据
|
|
|
- if len(df_majority) > len(df_minority):
|
|
|
- df_majority_downsampled = resample(df_majority, replace=False, n_samples=len(df_minority), random_state=42)
|
|
|
- df_balanced = pd.concat([df_majority_downsampled, df_minority])
|
|
|
- else:
|
|
|
- df_minority_upsampled = resample(df_minority, replace=True, n_samples=len(df_majority), random_state=42)
|
|
|
- df_balanced = pd.concat([df_majority, df_minority_upsampled])
|
|
|
-
|
|
|
- # 提取平衡后的特征和标签
|
|
|
- X_balanced = df_balanced.drop('label', axis=1).values
|
|
|
- y_balanced = df_balanced['label'].values
|
|
|
-
|
|
|
- return X_balanced, y_balanced
|
|
|
-
|
|
|
-
|
|
|
-def check_and_train_model():
|
|
|
- global model_trained
|
|
|
- if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
|
|
|
- X_train, y_train = prepare_training_data()
|
|
|
- if len(X_train) > 0 and len(y_train) > 0:
|
|
|
- X_train, y_train = balance_data(X_train, y_train) # 平衡数据
|
|
|
- X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
|
|
|
- model.fit(X_train, y_train)
|
|
|
- model_trained = True # 标记模型已经被训练
|
|
|
- logger.info("Model trained with %d samples", len(X_train))
|
|
|
- y_pred = model.predict(X_test)
|
|
|
- logger.info("\n" + classification_report(y_test, y_pred, zero_division=0))
|
|
|
- else:
|
|
|
- logger.info("Insufficient data to train the model")
|
|
|
- else:
|
|
|
- logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots),
|
|
|
- len(trade_data))
|
|
|
-
|
|
|
-
|
|
|
def predict_market_direction():
|
|
|
- global model_trained, prediction_window
|
|
|
+ global prediction_window
|
|
|
if len(order_book_snapshots) == 0 or len(trade_data) == 0:
|
|
|
return
|
|
|
|
|
|
- if not model_trained:
|
|
|
- return
|
|
|
-
|
|
|
- features = extract_features(order_book_snapshots[-1], trade_data[-1])
|
|
|
- if features is not None:
|
|
|
- feature_vector = np.array([list(features.values())])
|
|
|
- prediction = model.predict(feature_vector)[0]
|
|
|
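+ # Placeholder prediction logic: predict up (1) when the best ask is above the best bid, otherwise predict down (0). NOTE(review): the best ask is presumably above the best bid in normal conditions, so this would almost always yield 1 - confirm the intended rule.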
|
|
|
+ latest_order_book = order_book_snapshots[-1]
|
|
|
+ best_bid = float(latest_order_book['bids'][0][0])
|
|
|
+ best_ask = float(latest_order_book['asks'][0][0])
|
|
|
+ prediction = 1 if best_ask > best_bid else 0
|
|
|
|
|
|
- # 将预测结果添加到滑动窗口中
|
|
|
- prediction_window.append(prediction)
|
|
|
+ # 将预测结果添加到滑动窗口中
|
|
|
+ prediction_window.append(prediction)
|
|
|
|
|
|
- # 根据滑动窗口中的多数结果确定当前的预测方向
|
|
|
- if len(prediction_window) == prediction_window.maxlen:
|
|
|
- smoothed_prediction = 1 if sum(prediction_window) > len(prediction_window) / 2 else 0
|
|
|
- logger.info("Predicted Market Direction (smoothed): %s", "Up" if smoothed_prediction == 1 else "Down")
|
|
|
- show_message(smoothed_prediction)
|
|
|
+ # 根据滑动窗口中的多数结果确定当前的预测方向
|
|
|
+ if len(prediction_window) == prediction_window.maxlen:
|
|
|
+ smoothed_prediction = 1 if sum(prediction_window) > len(prediction_window) / 2 else 0
|
|
|
+ logger.info("Predicted Market Direction (smoothed): %s", "Up" if smoothed_prediction == 1 else "Down")
|
|
|
+ show_message(smoothed_prediction)
|
|
|
|
|
|
|
|
|
def show_message(market_direction):
|