| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- import json
- import numpy as np
- import pandas as pd
- import queue
- import threading
- from logger_config import logger
- from collections import deque
- from sklearn.linear_model import LogisticRegression
# Minimum number of buffered order book snapshots AND trades that must be
# available before the model is (re)trained.
DATA_THRESHOLD = 20

# Rolling buffers for the incoming data stream. Each deque keeps at most the
# 10000 most recent entries; the bound is by count, not by elapsed time.
trade_data = deque(maxlen=10000)
order_book_snapshots = deque(maxlen=10000)

# Classifier instance, plus a flag recording whether it has been fitted at
# least once.
model_trained = False
model = LogisticRegression()

# Thread-safe message queue and a stop signal.
# NOTE(review): neither is referenced by the code visible in this section.
stop_event = threading.Event()
messages = queue.Queue()
def on_message_trade(_ws, message):
    """Handle one trade message from the stream.

    Decodes the JSON payload, appends a normalized trade record to the
    module-level ``trade_data`` buffer, then runs the training check and
    the prediction step.

    Args:
        _ws: Connection object passed by the caller; not used here.
        message: JSON text whose ``data`` object provides ``p`` (price),
            ``q`` (quantity), ``T`` (timestamp in milliseconds) and ``m``
            (a truthy value marks the trade as a sell, otherwise a buy).
    """
    payload = json.loads(message)['data']

    # Keyword arguments are evaluated left to right, so fields are read in
    # the order price, quantity, timestamp, side.
    record = dict(
        price=float(payload['p']),
        qty=float(payload['q']),
        timestamp=pd.to_datetime(payload['T'], unit='ms'),
        side='sell' if payload['m'] else 'buy',
    )
    trade_data.append(record)

    check_and_train_model()
    predict_market_direction()
def on_message_depth(_ws, message):
    """Handle one order book depth message from the stream.

    Decodes the JSON payload, keeps the first ten bid and ask levels, stores
    them with the event time in the module-level ``order_book_snapshots``
    buffer, then runs the training check and the prediction step.

    Args:
        _ws: Connection object passed by the caller; not used here.
        message: JSON text whose ``data`` object provides ``b`` (bid
            levels), ``a`` (ask levels) and ``E`` (event time in
            milliseconds).
    """
    payload = json.loads(message)['data']

    top_bids = payload['b'][:10]
    top_asks = payload['a'][:10]
    event_time = pd.to_datetime(payload['E'], unit='ms')

    snapshot = dict(bids=top_bids, asks=top_asks, timestamp=event_time)
    order_book_snapshots.append(snapshot)

    check_and_train_model()
    predict_market_direction()
def extract_features(order_book, trade):
    """Build the feature mapping for one order book snapshot and one trade.

    Args:
        order_book: Mapping with ``'bids'`` and ``'asks'``; each is a
            non-empty sequence of levels where index 0 is the price and
            index 1 is the quantity (anything ``float()`` accepts).
        trade: Mapping with ``'qty'`` and ``'side'``.

    Returns:
        A dict with the keys ``spread``, ``bid_depth``, ``ask_depth``,
        ``trade_volume`` and ``trade_side``, in that insertion order.
        ``trade_side`` is 1 when ``trade['side'] == 'buy'`` and -1 otherwise.

    Raises:
        IndexError: If either side of the book has no levels.
    """
    # Spread between the first ask level and the first bid level.
    top_bid = float(order_book['bids'][0][0])
    top_ask = float(order_book['asks'][0][0])

    # Total quantity over all supplied levels on each side.
    total_bid_qty = sum(float(level[1]) for level in order_book['bids'])
    total_ask_qty = sum(float(level[1]) for level in order_book['asks'])

    # Trade size and signed direction.
    volume = trade['qty']
    if trade['side'] == 'buy':
        direction = 1
    else:
        direction = -1

    return {
        'spread': top_ask - top_bid,
        'bid_depth': total_bid_qty,
        'ask_depth': total_ask_qty,
        'trade_volume': volume,
        'trade_side': direction,
    }
def prepare_training_data():
    """Assemble the training matrix and labels from the buffered data.

    Sample ``i`` combines order book snapshot ``i`` with trade ``i``; its
    label compares the first bid price of snapshot ``i`` with that of
    snapshot ``i + 1`` via ``generate_label``. The number of samples is
    therefore ``min(len(order_book_snapshots) - 1, len(trade_data))``.

    NOTE(review): snapshots and trades are paired by buffer position, not
    by timestamp -- confirm this alignment is intended.

    Returns:
        A tuple ``(X_train, y_train)`` of NumPy arrays. Each row of
        ``X_train`` holds the values of ``extract_features`` in the order
        that function returns them. Both arrays are empty when no sample
        can be formed.
    """
    # Copy the buffers once. Indexed access into the middle of a deque is
    # O(n), so looking entries up by position inside the loop made this
    # function quadratic in the buffer size.
    snapshots = list(order_book_snapshots)
    trades = list(trade_data)

    X = []
    y = []
    # zip() stops at the shortest input, which gives the same sample count
    # as the bounds described in the docstring.
    for current_book, next_book, trade in zip(snapshots, snapshots[1:], trades):
        features = extract_features(current_book, trade)
        X.append(list(features.values()))

        current_price = float(current_book['bids'][0][0])
        future_price = float(next_book['bids'][0][0])
        y.append(generate_label(current_price, future_price))

    return np.array(X), np.array(y)
def generate_label(current_price, future_price):
    """Return 1 when ``future_price`` is strictly greater than
    ``current_price``, otherwise 0 (equal prices yield 0)."""
    if future_price > current_price:
        return 1
    return 0
def check_and_train_model():
    """Fit the shared model once enough data has been buffered.

    Training is attempted only when both ``order_book_snapshots`` and
    ``trade_data`` hold at least ``DATA_THRESHOLD`` entries. On a successful
    fit the module-level ``model_trained`` flag is set to True. Every
    outcome is reported through ``logger.info``.
    """
    global model_trained

    snapshot_count = len(order_book_snapshots)
    trade_count = len(trade_data)
    if snapshot_count < DATA_THRESHOLD or trade_count < DATA_THRESHOLD:
        logger.info("Not enough data to train the model: %d order book snapshots, %d trades", snapshot_count, trade_count)
        return

    X_train, y_train = prepare_training_data()
    if len(X_train) == 0 or len(y_train) == 0:
        logger.info("Insufficient data to train the model")
        return

    # A classifier needs at least two distinct labels; LogisticRegression.fit
    # raises ValueError on single-class targets, which would otherwise
    # propagate out of the stream callback. Skip this round and keep any
    # previously fitted model.
    if np.unique(y_train).size < 2:
        logger.info("Skipping training: all %d samples share a single label", len(y_train))
        return

    model.fit(X_train, y_train)
    model_trained = True
    logger.info("Model trained with %d samples", len(X_train))
def predict_market_direction():
    """Log a direction prediction for the most recent snapshot and trade.

    Does nothing beyond logging a notice when either buffer is empty or the
    model has not been fitted yet. Otherwise builds one feature row from the
    newest entries of ``order_book_snapshots`` and ``trade_data``, runs
    ``model.predict`` on it and logs "Up" for class 1 and "Down" for any
    other predicted value.
    """
    if not order_book_snapshots or not trade_data:
        logger.info("Not enough data to make a prediction")
        return

    if not model_trained:
        logger.info("Model is not trained yet")
        return

    latest = extract_features(order_book_snapshots[-1], trade_data[-1])
    if latest is None:
        return

    # predict() expects a 2-D input, hence the single-row wrapper list.
    sample = np.array([list(latest.values())])
    predicted = model.predict(sample)
    direction = "Up" if predicted[0] == 1 else "Down"
    logger.info("Predicted Market Direction: %s", direction)
|