skyffire 1 год назад
Родитель
Сommit
a243c0ac88
1 измененных файлов с 13 добавлено и 8 удалено
  1. 13 8
      kappa/data_processing.py

+ 13 - 8
kappa/data_processing.py

@@ -23,6 +23,10 @@ stop_event = threading.Event()
 k_initial = 10
 A_initial = 100
 
+# 定义参数范围
+bounds = [(10, 1000.0),             # A 的范围
+          (0.01, 100.0)]            # k 的范围
+
 # 假设S0是初始的参考价格
 S0 = -1
 
@@ -132,24 +136,23 @@ def objective_function(params, delta_max, log_lambda_hat_value, log_integral_phi
     :return: 目标函数值
     """
     A, k = params
-    if A <= 0:
-        return 0
 
     residuals = []
     for delta in range(1, delta_max + 1):
-        residual = (log_lambda_hat_value + k * delta - np.log(A) - log_integral_phi_value) ** 2
+        rst = (log_lambda_hat_value + k * delta - np.log(A) - log_integral_phi_value)
+        residual = rst ** 2
         residuals.append(residual)
     return np.sum(residuals)
 
 
 def process_depth_data():
     global order_book_snapshots, trade_snapshots, spread_delta_snapshots
+    global k_initial, A_initial, S0
 
     # 数据预热,至少10条深度数据以及100条成交数据才能用于计算
     if len(order_book_snapshots) < 10 or len(trade_snapshots) < 100:
         return
 
-    global k_initial, A_initial, S0
     S_values = [((snapshot['bids'][0][0] + snapshot['asks'][0][0]) / 2) for snapshot in order_book_snapshots]
 
     if S0 < 0:
@@ -184,12 +187,14 @@ def process_depth_data():
 
     # 优化目标函数以找到最优的 A 和 k
     result = minimize(objective_function, np.array([A_initial, k_initial]),
-                      args=(delta_max, log_lambda_hat_value, log_integral_phi_value))
+                      args=(delta_max, log_lambda_hat_value, log_integral_phi_value),
+                      bounds=bounds)
 
     if result.success:
-        A_optimal, k_optimal = result.x
-        logger.info(f"Optimal A: {A_optimal}, Optimal k: {k_optimal}")
+        A_initial, k_initial = result.x
+
+        logger.info(f"Optimal A: {A_initial}, Optimal k: {k_initial}")
     else:
         logger.error("Optimization failed")
 
-    # logger.info("log(λ(δ)) 的值: " + str(log_lambda_hat_value) + " log(∫ φ(k, ξ) dξ) 的值: " + str(log_integral_phi_value))
+    # logger.info("log(λ(δ)): {}, log(∫ φ(k, ξ) dξ): {}".format(log_lambda_hat_value, log_integral_phi_value))