predictor.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import time
  2. import utils
  3. import numpy as np
  4. class Predictor:
  5. '''
  6. reference
  7. '''
  8. def __init__(self, ref_name = ["Unknown Market"], alpha = [1.0 for _ in range(99)], gamma = 0.999):
  9. self.loop = 0
  10. self.arr = []
  11. self.trade_mp_series = []
  12. self.ref_mp_series = []
  13. self.ref_num = len(ref_name)
  14. ### 定价
  15. self.window = 10
  16. ### 价格系数
  17. self.alpha = alpha
  18. # 参考
  19. print('定价系数:', gamma)
  20. self.gamma = gamma
  21. self.avg_spread = [None for _ in range(self.ref_num)]
  22. def processer(self):
  23. '''
  24. 计算任务
  25. '''
  26. # update trade mp
  27. bp=self.arr[-1][utils.BP_INDEX]
  28. ap=self.arr[-1][utils.AP_INDEX]
  29. mp=(bp+ap)*0.5
  30. self.trade_mp_series.append(mp)
  31. # 更新参考盘口mp
  32. ref_mp = []
  33. for i in range(self.ref_num):
  34. bp = self.arr[-1][utils.LEN*(1+i)+utils.BP_INDEX]
  35. ap = self.arr[-1][utils.LEN*(1+i)+utils.AP_INDEX]
  36. mp=(bp+ap)*0.5
  37. ref_mp.append(mp)
  38. self.ref_mp_series.append(ref_mp)
  39. # 偏差计算
  40. self._update_avg_spread()
  41. def _update_avg_spread(self):
  42. '''
  43. 更新平均偏差
  44. '''
  45. # 计算偏差1
  46. for i in range(self.ref_num):
  47. bias = self.ref_mp_series[-1][i]*self.alpha[i] - self.trade_mp_series[-1]
  48. # 如果是刚启动 gamma不能太大
  49. if self.loop < 100:
  50. gamma = 0.9
  51. else:
  52. gamma = self.gamma
  53. if self.avg_spread[i] == None:
  54. self.avg_spread[i] = bias
  55. else:
  56. self.avg_spread[i] = self.avg_spread[i]*gamma + bias*(1-gamma)
  57. def check_length(self):
  58. # 行情缓存
  59. if len(self.arr) > self.window:del(self.arr[0])
  60. if len(self.trade_mp_series) > self.window:del(self.trade_mp_series[0])
  61. if len(self.ref_mp_series) > self.window:del(self.ref_mp_series[0])
  62. # @utils.timeit
  63. def onTime(self, data):
  64. if isinstance(data, list):
  65. if len(data) > 0:
  66. self.loop += 1
  67. self.arr.append(data)
  68. self.processer()
  69. self.check_length()
  70. else:
  71. print("行情数据为空")
  72. else:
  73. print("行情数据为None")
  74. # @utils.timeit
  75. def Get_ref(self, ref_ticker):
  76. '''
  77. get ref price
  78. '''
  79. ref_mid = []
  80. for i in range(self.ref_num):
  81. ref_mid.append(
  82. [ref_ticker[i][0]*self.alpha[i] - self.avg_spread[i], ref_ticker[i][1]*self.alpha[i] - self.avg_spread[i]]
  83. )
  84. return ref_mid
  85. if __name__ == "__main__":
  86. import pandas as pd
  87. import numpy as np
  88. import matplotlib
  89. matplotlib.use('TkAgg')
  90. import matplotlib.pyplot as plt
  91. arr = pd.read_csv('history/ftm_usdt_binance_usdt_swap.csv').values.tolist()
  92. def line_data_to_tickers(data, ref_num):
  93. ref_tickers = []
  94. for i in range(ref_num):
  95. ref_tickers.append([data[utils.LEN*(i+1)+utils.BP_INDEX], data[utils.LEN*(i+1)+utils.AP_INDEX]])
  96. return ref_tickers
  97. ref_num = len(arr[0])//utils.LEN - 1
  98. p = Predictor(ref_name=["unkwon" for _ in range(ref_num)])
  99. t = []
  100. ref_index = 1
  101. for data in arr:
  102. p.onTime(data)
  103. trade_mp = (data[utils.BP_INDEX] + data[utils.AP_INDEX])*0.5
  104. ref_price = p.Get_ref(line_data_to_tickers(data, ref_num))
  105. t.append([
  106. trade_mp,
  107. (ref_price[ref_index][0]+ref_price[ref_index][1])*0.5,
  108. ])
  109. t = pd.DataFrame(t,columns=['mp','ref'])
  110. if 1:
  111. plt.figure()
  112. plt.plot(t['mp'],'k')
  113. plt.plot(t['ref'],'g')
  114. plt.grid()
  115. plt.show()