|
|
@@ -14,6 +14,7 @@ from logger_config import logger
|
|
|
# 假设我们有一个数据流,订单簿和成交数据
|
|
|
order_book_snapshots = deque(maxlen=100) # 存储过去100ms的订单簿快照
|
|
|
trade_data = deque(maxlen=100) # 存储过去100ms的成交数据
|
|
|
+prediction_window = deque(maxlen=10) # 用于平滑预测结果的滑动窗口
|
|
|
|
|
|
# 数据积累的阈值
|
|
|
DATA_THRESHOLD = 20
|
|
|
@@ -123,7 +124,8 @@ def generate_label(current_price, future_price):
|
|
|
|
|
|
def balance_data(X, y):
|
|
|
# 将数据转换为 DataFrame,便于处理
|
|
|
- df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count', 'ask_count', 'timestamp'])
|
|
|
+ df = pd.DataFrame(X, columns=['spread', 'bid_depth', 'ask_depth', 'trade_volume', 'trade_side', 'bid_count',
|
|
|
+ 'ask_count', 'timestamp'])
|
|
|
df['label'] = y
|
|
|
|
|
|
# 分别获取两类数据
|
|
|
@@ -163,11 +165,12 @@ def check_and_train_model():
|
|
|
else:
|
|
|
logger.info("Insufficient data to train the model")
|
|
|
else:
|
|
|
- logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots), len(trade_data))
|
|
|
+ logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots),
|
|
|
+ len(trade_data))
|
|
|
|
|
|
|
|
|
def predict_market_direction():
|
|
|
- global model_trained
|
|
|
+ global model_trained, prediction_window
|
|
|
if len(order_book_snapshots) == 0 or len(trade_data) == 0:
|
|
|
return
|
|
|
|
|
|
@@ -177,9 +180,16 @@ def predict_market_direction():
|
|
|
features = extract_features(order_book_snapshots[-1], trade_data[-1])
|
|
|
if features is not None:
|
|
|
feature_vector = np.array([list(features.values())])
|
|
|
- prediction = model.predict(feature_vector)
|
|
|
- logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
|
|
|
- show_message(prediction[0])
|
|
|
+ prediction = model.predict(feature_vector)[0]
|
|
|
+
|
|
|
+ # 将预测结果添加到滑动窗口中
|
|
|
+ prediction_window.append(prediction)
|
|
|
+
|
|
|
+ # 根据滑动窗口中的多数结果确定当前的预测方向
|
|
|
+ if len(prediction_window) == prediction_window.maxlen:
|
|
|
+ smoothed_prediction = 1 if sum(prediction_window) > len(prediction_window) / 2 else 0
|
|
|
+ logger.info("Predicted Market Direction (smoothed): %s", "Up" if smoothed_prediction == 1 else "Down")
|
|
|
+ show_message(smoothed_prediction)
|
|
|
|
|
|
|
|
|
def show_message(market_direction):
|