|
|
@@ -1,10 +1,8 @@
|
|
|
import json
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
-import time
|
|
|
import queue
|
|
|
import threading
|
|
|
-from datetime import datetime
|
|
|
from logger_config import logger
|
|
|
from collections import deque
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
@@ -14,7 +12,7 @@ order_book_snapshots = deque(maxlen=10000) # 存储过去100ms的订单簿快
|
|
|
trade_data = deque(maxlen=10000) # 存储过去100ms的成交数据
|
|
|
|
|
|
# 数据积累的阈值
|
|
|
-DATA_THRESHOLD = 10
|
|
|
+DATA_THRESHOLD = 20
|
|
|
model = LogisticRegression() # 初始化模型
|
|
|
model_trained = False # 标记模型是否已经被训练
|
|
|
messages = queue.Queue() # 创建一个线程安全队列
|
|
|
@@ -32,8 +30,8 @@ def on_message_trade(_ws, message):
|
|
|
'side': 'sell' if json_message['data']['m'] else 'buy'
|
|
|
}
|
|
|
trade_data.append(trade)
|
|
|
- predict_market_direction()
|
|
|
check_and_train_model()
|
|
|
+ predict_market_direction()
|
|
|
|
|
|
|
|
|
def on_message_depth(_ws, message):
|
|
|
@@ -47,8 +45,8 @@ def on_message_depth(_ws, message):
|
|
|
'asks': asks,
|
|
|
'timestamp': timestamp
|
|
|
})
|
|
|
- predict_market_direction()
|
|
|
check_and_train_model()
|
|
|
+ predict_market_direction()
|
|
|
|
|
|
|
|
|
def extract_features(order_book, trade):
|
|
|
@@ -82,6 +80,8 @@ def prepare_training_data():
|
|
|
y = []
|
|
|
|
|
|
for i in range(len(order_book_snapshots) - 1):
|
|
|
+ if i + 1 >= len(order_book_snapshots) or i >= len(trade_data):
|
|
|
+ break
|
|
|
current_order_book = order_book_snapshots[i]
|
|
|
current_trade = trade_data[i]
|
|
|
future_order_book = order_book_snapshots[i + 1]
|
|
|
@@ -111,9 +111,14 @@ def check_and_train_model():
|
|
|
global model_trained
|
|
|
if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
|
|
|
X_train, y_train = prepare_training_data()
|
|
|
- model.fit(X_train, y_train)
|
|
|
- model_trained = True # 标记模型已经被训练
|
|
|
- logger.info("Model trained with %d samples", len(X_train))
|
|
|
+ if len(X_train) > 0 and len(y_train) > 0:
|
|
|
+ model.fit(X_train, y_train)
|
|
|
+ model_trained = True # 标记模型已经被训练
|
|
|
+ logger.info("Model trained with %d samples", len(X_train))
|
|
|
+ else:
|
|
|
+ logger.info("Insufficient data to train the model")
|
|
|
+ else:
|
|
|
+ logger.info("Not enough data to train the model: %d order book snapshots, %d trades", len(order_book_snapshots), len(trade_data))
|
|
|
|
|
|
|
|
|
def predict_market_direction():
|
|
|
@@ -130,4 +135,4 @@ def predict_market_direction():
|
|
|
if features is not None:
|
|
|
feature_vector = np.array([list(features.values())])
|
|
|
prediction = model.predict(feature_vector)
|
|
|
- logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
|
|
|
+ logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
|