ソースを参照

能大概计算A值和k值了,但是还有一些问题。

skyffire 1 年間 前
コミット
3abc2a6d45
2 ファイル変更72 行追加9 行削除
  1. 1 0
      .gitignore
  2. 71 9
      kappa/data_processing.py

+ 1 - 0
.gitignore

@@ -1 +1,2 @@
 .idea
+.ipynb_checkpoints

+ 71 - 9
kappa/data_processing.py

@@ -1,24 +1,41 @@
 import json
+from decimal import Decimal, getcontext
+
 import pandas as pd
 import threading
 from collections import deque
 from scipy.integrate import trapz
 import numpy as np
+from scipy.optimize import minimize
 from logger_config import logger
 
+# 设置全局精度
+getcontext().prec = 28
+
 # 假设我们有一个数据流,订单簿和成交数据
 order_book_snapshots = deque(maxlen=600)        # 存储过去600个订单簿快照
+spread_delta_snapshots = deque(maxlen=6000)     # 存储过去600个价差数据(最小变动价格的倍数)
 trade_snapshots = deque(maxlen=6000)            # 存储过去6000个成交数据
 
 stop_event = threading.Event()
 
 # 初始参数
-k_initial = 0.5
-A_initial = 1.0
+k_initial = 10
+A_initial = 100
 
 # 假设S0是初始的参考价格
 S0 = -1
 
+def get_tick_size_from_prices(ask_price, bid_price):
+    # 获取价格的小数位数
+    ask_decimal_places = len(str(ask_price).split('.')[1])
+    bid_decimal_places = len(str(bid_price).split('.')[1])
+
+    # 确定最小变动单元
+    tick_size = 10 ** -max(ask_decimal_places, bid_decimal_places)
+
+    return tick_size
+
 
 def on_message_trade(_ws, message):
     global trade_snapshots
@@ -34,16 +51,26 @@ def on_message_trade(_ws, message):
 
 
 def on_message_depth(_ws, message):
-    global order_book_snapshots
+    global order_book_snapshots, spread_delta_snapshots
     json_message = json.loads(message)
     bids = [[float(price), float(quantity)] for price, quantity in json_message['data']['b'][:10]]
     asks = [[float(price), float(quantity)] for price, quantity in json_message['data']['a'][:10]]
     timestamp = pd.to_datetime(json_message['data']['E'], unit='ms')
-    order_book_snapshots.append({
+    depth = {
         'bids': bids,
         'asks': asks,
         'timestamp': timestamp
-    })
+    }
+    order_book_snapshots.append(depth)
+
+    # 求价差
+    ask_price = Decimal(str(asks[0][0]))
+    bid_price = Decimal(str(bids[0][0]))
+    tick_size = get_tick_size_from_prices(ask_price, bid_price)
+    spread = float(ask_price - bid_price)
+    spread_delta = int(spread / tick_size)
+    spread_delta_snapshots.append(spread_delta)
+
     process_depth_data()
 
 
@@ -95,11 +122,31 @@ def estimate_lambda(waiting_times, T):
     return lambda_hat
 
 
+def objective_function(params, delta_max, log_lambda_hat_value, log_integral_phi_value):
+    """
+    目标函数 r(A, k)
+    :param params: 包含 A 和 k 的数组
+    :param delta_max: 最大的价格偏移
+    :param log_lambda_hat_value: log(λ(δ)) 的值
+    :param log_integral_phi_value: log(∫ φ(k, ξ) dξ) 的值
+    :return: 目标函数值
+    """
+    A, k = params
+    if A <= 0:
+        return 0
+
+    residuals = []
+    for delta in range(1, delta_max + 1):
+        residual = (log_lambda_hat_value + k * delta - np.log(A) - log_integral_phi_value) ** 2
+        residuals.append(residual)
+    return np.sum(residuals)
+
+
 def process_depth_data():
-    global order_book_snapshots, trade_snapshots
+    global order_book_snapshots, trade_snapshots, spread_delta_snapshots
 
     # 数据预热,至少10条深度数据以及100条成交数据才能用于计算
-    if len(order_book_snapshots) < 10 and len(trade_snapshots) < 100:
+    if len(order_book_snapshots) < 10 or len(trade_snapshots) < 100:
         return
 
     global k_initial, A_initial, S0
@@ -126,8 +173,23 @@ def process_depth_data():
     # 时间窗口的大小
     T = pd.to_datetime(100, unit='ms') - pd.to_datetime(0, unit='ms')
 
+    # 计算 λ(δ) 的估计值
     lambda_hat = estimate_lambda(waiting_times, T)
-    # logger.info("λ(δ) 的值: " + str(lambda_hat) + "log(∫ φ(k, ξ) dξ) 的值: " + str(log_integral_phi_value))
+
+    # 计算 log(λ(δ))
     log_lambda_hat_value = np.log(lambda_hat)
-    logger.info("log(λ(δ)) 的值: " + str(log_lambda_hat_value) + "log(∫ φ(k, ξ) dξ) 的值: " + str(log_integral_phi_value))
 
+    # ========================== 校准 A 和 k =============================
+    delta_max = np.max(spread_delta_snapshots)
+
+    # 优化目标函数以找到最优的 A 和 k
+    result = minimize(objective_function, np.array([A_initial, k_initial]),
+                      args=(delta_max, log_lambda_hat_value, log_integral_phi_value))
+
+    if result.success:
+        A_optimal, k_optimal = result.x
+        logger.info(f"Optimal A: {A_optimal}, Optimal k: {k_optimal}")
+    else:
+        logger.error("Optimization failed")
+
+    # logger.info("log(λ(δ)) 的值: " + str(log_lambda_hat_value) + " log(∫ φ(k, ξ) dξ) 的值: " + str(log_integral_phi_value))