predictor.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import time
  2. import utils
  3. import numpy as np
  4. from loguru import logger
  5. class Predictor:
  6. '''
  7. reference
  8. '''
  9. def __init__(self, ref_name = ["Unknown Market"], alpha = [1.0 for _ in range(99)], gamma = 0.999):
  10. self.loop = 0
  11. self.arr = []
  12. self.trade_mp_series = []
  13. self.ref_mp_series = []
  14. self.ref_num = len(ref_name)
  15. ### 定价
  16. self.window = 10
  17. ### 价格系数
  18. self.alpha = alpha
  19. # 参考
  20. print('定价系数:', gamma)
  21. self.gamma = gamma
  22. self.avg_spread = [None for _ in range(self.ref_num)]
  23. def processer(self):
  24. '''
  25. 计算任务
  26. '''
  27. # update trade mp
  28. bp=self.arr[-1][utils.BP_INDEX]
  29. ap=self.arr[-1][utils.AP_INDEX]
  30. mp=(bp+ap)*0.5
  31. self.trade_mp_series.append(mp)
  32. # 更新参考盘口mp
  33. ref_mp = []
  34. for i in range(self.ref_num):
  35. bp = self.arr[-1][utils.LEN*(1+i)+utils.BP_INDEX]
  36. ap = self.arr[-1][utils.LEN*(1+i)+utils.AP_INDEX]
  37. mp=(bp+ap)*0.5
  38. ref_mp.append(mp)
  39. self.ref_mp_series.append(ref_mp)
  40. # 偏差计算
  41. self._update_avg_spread()
  42. def _update_avg_spread(self):
  43. '''
  44. 更新平均偏差
  45. '''
  46. # 计算偏差1
  47. for i in range(self.ref_num):
  48. bias = self.ref_mp_series[-1][i]*self.alpha[i] - self.trade_mp_series[-1]
  49. # 如果是刚启动 gamma不能太大
  50. if self.loop < 100:
  51. gamma = 0.9
  52. else:
  53. gamma = self.gamma
  54. if self.avg_spread[i] == None:
  55. self.avg_spread[i] = bias
  56. else:
  57. self.avg_spread[i] = self.avg_spread[i]*gamma + bias*(1-gamma)
  58. def check_length(self):
  59. # 行情缓存
  60. if len(self.arr) > self.window:del(self.arr[0])
  61. if len(self.trade_mp_series) > self.window:del(self.trade_mp_series[0])
  62. if len(self.ref_mp_series) > self.window:del(self.ref_mp_series[0])
  63. # @utils.timeit
  64. def onTime(self, data):
  65. if isinstance(data, list):
  66. if len(data) > 0:
  67. self.loop += 1
  68. self.arr.append(data)
  69. self.processer()
  70. self.check_length()
  71. else:
  72. print("行情数据为空")
  73. else:
  74. print("行情数据为None")
  75. # @utils.timeit
  76. def Get_ref(self, ref_ticker):
  77. '''
  78. get ref price
  79. '''
  80. ref_mid = []
  81. for i in range(self.ref_num):
  82. ref_mid.append(
  83. [ref_ticker[i][0]*self.alpha[i] - self.avg_spread[i], ref_ticker[i][1]*self.alpha[i] - self.avg_spread[i]]
  84. )
  85. return ref_mid
  86. if __name__ == "__main__":
  87. import pandas as pd
  88. import numpy as np
  89. import matplotlib
  90. matplotlib.use('TkAgg')
  91. import matplotlib.pyplot as plt
  92. arr = pd.read_csv('history/ftm_usdt_binance_usdt_swap.csv').values.tolist()
  93. def line_data_to_tickers(data, ref_num):
  94. ref_tickers = []
  95. for i in range(ref_num):
  96. ref_tickers.append([data[utils.LEN*(i+1)+utils.BP_INDEX], data[utils.LEN*(i+1)+utils.AP_INDEX]])
  97. return ref_tickers
  98. ref_num = len(arr[0])//utils.LEN - 1
  99. p = Predictor(ref_name=["unkwon" for _ in range(ref_num)])
  100. t = []
  101. ref_index = 1
  102. for data in arr:
  103. p.onTime(data)
  104. trade_mp = (data[utils.BP_INDEX] + data[utils.AP_INDEX])*0.5
  105. ref_price = p.Get_ref(line_data_to_tickers(data, ref_num))
  106. t.append([
  107. trade_mp,
  108. (ref_price[ref_index][0]+ref_price[ref_index][1])*0.5,
  109. ])
  110. t = pd.DataFrame(t,columns=['mp','ref'])
  111. if 1:
  112. plt.figure()
  113. plt.plot(t['mp'],'k')
  114. plt.plot(t['ref'],'g')
  115. plt.grid()
  116. plt.show()