|
|
@@ -1,5 +1,4 @@
|
|
|
import json
|
|
|
-
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import time
|
|
|
@@ -15,12 +14,14 @@ order_book_snapshots = deque(maxlen=10000) # 存储过去100ms的订单簿快
|
|
|
trade_data = deque(maxlen=10000) # 存储过去100ms的成交数据
|
|
|
|
|
|
# 数据积累的阈值
|
|
|
-DATA_THRESHOLD = 100
|
|
|
+DATA_THRESHOLD = 10
|
|
|
model = LogisticRegression() # 初始化模型
|
|
|
+model_trained = False # 标记模型是否已经被训练
|
|
|
messages = queue.Queue() # 创建一个线程安全队列
|
|
|
|
|
|
stop_event = threading.Event()
|
|
|
|
|
|
+
|
|
|
def on_message_trade(_ws, message):
|
|
|
global trade_data
|
|
|
json_message = json.loads(message)
|
|
|
@@ -32,6 +33,7 @@ def on_message_trade(_ws, message):
|
|
|
}
|
|
|
trade_data.append(trade)
|
|
|
predict_market_direction()
|
|
|
+ check_and_train_model()
|
|
|
|
|
|
|
|
|
def on_message_depth(_ws, message):
|
|
|
@@ -46,6 +48,7 @@ def on_message_depth(_ws, message):
|
|
|
'timestamp': timestamp
|
|
|
})
|
|
|
predict_market_direction()
|
|
|
+ check_and_train_model()
|
|
|
|
|
|
|
|
|
def extract_features(order_book, trade):
|
|
|
@@ -105,46 +108,26 @@ def generate_label(current_price, future_price):
|
|
|
|
|
|
|
|
|
def check_and_train_model():
|
|
|
+ global model_trained
|
|
|
if len(order_book_snapshots) >= DATA_THRESHOLD and len(trade_data) >= DATA_THRESHOLD:
|
|
|
X_train, y_train = prepare_training_data()
|
|
|
model.fit(X_train, y_train)
|
|
|
- logger.info("Model trained with", len(X_train), "samples")
|
|
|
+ model_trained = True # 标记模型已经被训练
|
|
|
+ logger.info("Model trained with %d samples", len(X_train))
|
|
|
|
|
|
|
|
|
def predict_market_direction():
|
|
|
+ global model_trained
|
|
|
if len(order_book_snapshots) == 0 or len(trade_data) == 0:
|
|
|
logger.info("Not enough data to make a prediction")
|
|
|
return
|
|
|
|
|
|
+ if not model_trained:
|
|
|
+ logger.info("Model is not trained yet")
|
|
|
+ return
|
|
|
+
|
|
|
features = extract_features(order_book_snapshots[-1], trade_data[-1])
|
|
|
if features is not None:
|
|
|
feature_vector = np.array([list(features.values())])
|
|
|
prediction = model.predict(feature_vector)
|
|
|
- logger.info("Predicted Market Direction:", "Up" if prediction[0] == 1 else "Down")
|
|
|
-
|
|
|
-
|
|
|
-# # 计算成交概率
|
|
|
-# def show_message():
|
|
|
-# global df_order_book, last_trade
|
|
|
-#
|
|
|
-# if not df_order_book.empty and last_trade['price'] is not None:
|
|
|
-# last_price = last_trade['price']
|
|
|
-# asks = [[price, qty] for price, qty in
|
|
|
-# zip(df_order_book['ask_price'].iloc[0], df_order_book['ask_qty'].iloc[0])]
|
|
|
-# bids = [[price, qty] for price, qty in
|
|
|
-# zip(df_order_book['bid_price'].iloc[0], df_order_book['bid_qty'].iloc[0])]
|
|
|
-#
|
|
|
-# asks_sorted = sorted(asks, key=lambda x: x[0])
|
|
|
-# bids_sorted = sorted(bids, key=lambda x: x[0], reverse=True)
|
|
|
-#
|
|
|
-# last_qty = last_trade['qty']
|
|
|
-# side = last_trade['side']
|
|
|
-# data = {
|
|
|
-# "asks": asks_sorted,
|
|
|
-# "bids": bids_sorted,
|
|
|
-# "last_price": last_price,
|
|
|
-# "last_qty": last_qty,
|
|
|
-# "side": side,
|
|
|
-# "time": int(time.time() * 1000)
|
|
|
-# }
|
|
|
-# messages.put(data)
|
|
|
+ logger.info("Predicted Market Direction: %s", "Up" if prediction[0] == 1 else "Down")
|