Bläddra i källkod

状态转移的部分好像搞完了。

skyffire 1 år sedan
förälder
incheckning
4b3b2d1ae5
1 ändrade filer med 20 tillägg och 16 borttagningar
  1. 20 16
      binance_gp_demo.py

+ 20 - 16
binance_gp_demo.py

@@ -33,8 +33,8 @@ logger.addHandler(handler)
 
 # 步骤二:订阅Binance的成交数据和订单簿数据
 # Binance WebSocket API URL
-SOCKET_TRADE = "wss://stream.binance.com:9443/ws/notusdt@trade"
-SOCKET_DEPTH = "wss://stream.binance.com:9443/ws/notusdt@depth20@100ms"
+SOCKET_TRADE = "wss://stream.binance.com:9443/ws/btcusdt@trade"
+SOCKET_DEPTH = "wss://stream.binance.com:9443/ws/btcusdt@depth20@100ms"
 
 # Initialize the DataFrame
 df_trades = pd.DataFrame(columns=['price', 'qty', 'timestamp'])
@@ -105,17 +105,19 @@ states = ['tight', 'normal', 'wide']
 transition_matrix = np.zeros((3, 3))
 
 
-# Function to update the transition matrix based on historical data
+# 定义函数来更新转移矩阵
 def update_transition_matrix(df):
     global transition_matrix
     for i in range(len(df) - 1):
         current_state = df['state'].iloc[i]
         next_state = df['state'].iloc[i + 1]
         transition_matrix[states.index(current_state), states.index(next_state)] += 1
-    transition_matrix = transition_matrix / transition_matrix.sum(axis=1, keepdims=True)
 
+    row_sums = transition_matrix.sum(axis=1, keepdims=True)
+    row_sums[row_sums == 0] = 1
+    transition_matrix = transition_matrix / row_sums
 
-# Function to classify the spread into states
+# 定义函数来分类价差状态
 def classify_spread(spread):
     if spread < 0.01:
         return 'tight'
@@ -124,21 +126,23 @@ def classify_spread(spread):
     else:
         return 'wide'
 
-
-# Function to calculate spread and classify it
+# 定义函数来计算价差并进行分类
 def calculate_and_classify_spread():
     global df_trades
     df_trades['spread'] = df_trades['price'].diff().abs()
     df_trades['state'] = df_trades['spread'].apply(classify_spread)
 
+# 定义周期性更新转移矩阵的函数
+stop_event = threading.Event()
 
-# Update the transition matrix periodically
 def update_transition_matrix_periodically():
-    calculate_and_classify_spread()
-    update_transition_matrix(df_trades)
-    logger.info("Transition Matrix:\n%s", transition_matrix)
-
-
-# Run the update function periodically, e.g., every minute
-transition_matrix_update_thread = threading.Timer(60.0, update_transition_matrix_periodically)
-transition_matrix_update_thread.start()
+    while not stop_event.is_set():
+        calculate_and_classify_spread()
+        update_transition_matrix(df_trades)
+        current_state = df_trades['state'].iloc[-1] if not df_trades.empty else 'unknown'
+        logger.info("Current State: %s\nTransition Matrix:\n%s\n", current_state, transition_matrix)
+        stop_event.wait(5)  # 每5秒更新一次
+
+# 创建并启动线程
+transition_matrix_update_thread = threading.Thread(target=update_transition_matrix_periodically)
+transition_matrix_update_thread.start()