predictor.rs 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. use std::collections::BTreeMap;
  2. use rust_decimal::prelude::*;
  3. use rust_decimal_macros::dec;
  4. use tracing::{instrument};
  5. use standard::Ticker;
  6. use global::public_params;
  7. #[derive(Debug)]
  8. pub struct Predictor {
  9. pub loop_count: i64, // 统计当前预测器更新的次数,loop
  10. pub market_info_list: Vec<Vec<Decimal>>, // TODO 这里存放的是一个市场数据汇总,后面有时间再优化,arr
  11. pub mid_price_list: Vec<Decimal>, // 中间价的数组,trade_mp_series
  12. pub ref_mid_price_per_exchange_per_frame: Vec<Vec<Decimal>>, // 参考交易所的中间价的数组(注意是二维的,因为参考交易所有多个),ref_mp_series
  13. pub ref_exchange_length: usize, // 参考交易所数量,ref_num
  14. pub data_length_max: usize, // 各类数据极限长度,防止爆内存,window
  15. pub alpha: Vec<Decimal>, // 价格系数
  16. pub gamma: Decimal, // 定价系数
  17. pub avg_spread_list: Vec<Decimal>, // 平均价差
  18. }
  19. /*
  20. 使用Builder设计模式创建价格预测器,可以有效提高代码整洁度
  21. 下面的单元测试有使用示例
  22. */
  23. impl Predictor {
  24. pub fn new(ref_exchange_length: usize) -> Self {
  25. Self {
  26. loop_count: 0,
  27. market_info_list: vec![],
  28. mid_price_list: vec![],
  29. ref_mid_price_per_exchange_per_frame: vec![],
  30. ref_exchange_length,
  31. data_length_max: 10,
  32. alpha: vec![Decimal::new(1, 0); 100],
  33. gamma: Decimal::from_f64(0.999).unwrap(),
  34. avg_spread_list: vec![dec!(0); ref_exchange_length],
  35. }
  36. }
  37. pub fn alpha(mut self, alpha: Vec<Decimal>) -> Self {
  38. self.alpha = alpha;
  39. self
  40. }
  41. pub fn gamma(mut self, gamma: Decimal) -> Self {
  42. self.gamma = gamma;
  43. self
  44. }
  45. // 计算任务,python里写作processer,是个错误的单词
  46. #[instrument(skip(self), level="TRACE")]
  47. fn processor(&mut self) {
  48. let last_market_info = self.market_info_list.last().unwrap();
  49. // 更新mid_price
  50. let bid_price = last_market_info[public_params::BID_PRICE_INDEX];
  51. let ask_price = last_market_info[public_params::ASK_PRICE_INDEX];
  52. let mid_price = (bid_price + ask_price) * dec!(0.5);
  53. self.mid_price_list.push(mid_price);
  54. // 更新参考ref_mid_price
  55. let mut ref_mid_price_per_exchange = vec![];
  56. for ref_index in 0..self.ref_exchange_length {
  57. let ref_bid_price = last_market_info[public_params::LENGTH*(1+ref_index)+public_params::BID_PRICE_INDEX];
  58. let ref_ask_price = last_market_info[public_params::LENGTH*(1+ref_index)+public_params::ASK_PRICE_INDEX];
  59. let ref_mid_price = (ref_bid_price + ref_ask_price) * dec!(0.5);
  60. // 依照交易所次序添加到ref_mid_price_per_exchange中
  61. ref_mid_price_per_exchange.push(ref_mid_price);
  62. }
  63. self.ref_mid_price_per_exchange_per_frame.push(ref_mid_price_per_exchange);
  64. // 价差更新
  65. (*self).update_avg_spread()
  66. }
  67. // 更新平均价差,_update_avg_spread
  68. #[instrument(skip(self), level="TRACE")]
  69. fn update_avg_spread(&mut self) {
  70. let last_ref_mid_price_per_exchange = self.ref_mid_price_per_exchange_per_frame.last().unwrap();
  71. let mid_price_last = self.mid_price_list.last().unwrap();
  72. for ref_index in 0..self.ref_exchange_length {
  73. let bias = last_ref_mid_price_per_exchange[ref_index] * self.alpha[ref_index] - mid_price_last;
  74. let mut gamma = self.gamma;
  75. // 如果程序刚刚启动,gamma值不能太大
  76. if self.loop_count < 100 {
  77. gamma = dec!(0.9);
  78. }
  79. // 检测是否初始化
  80. if dec!(0).eq(&self.avg_spread_list[ref_index]) {
  81. self.avg_spread_list[ref_index] = bias;
  82. } else {
  83. self.avg_spread_list[ref_index] = self.avg_spread_list[ref_index] * gamma + bias*(dec!(1)-gamma);
  84. }
  85. }
  86. }
  87. // 长度限定
  88. #[instrument(skip(self), level="TRACE")]
  89. fn check_length(&mut self) {
  90. // 市场汇总信息长度限定
  91. if self.market_info_list.len() > self.data_length_max {
  92. self.market_info_list.remove(0);
  93. }
  94. // 交易交易所的mid_price长度限定
  95. if self.mid_price_list.len() > self.data_length_max {
  96. self.mid_price_list.remove(0);
  97. }
  98. // 参考交易所的长度限定
  99. if self.ref_mid_price_per_exchange_per_frame.len() > self.data_length_max {
  100. self.ref_mid_price_per_exchange_per_frame.remove(0);
  101. }
  102. }
  103. // 市场信息处理器,也是python里的onTime方法
  104. #[instrument(skip(self, new_market_info), level="TRACE")]
  105. pub fn market_info_handler(&mut self, new_market_info: &Vec<Decimal>) {
  106. // 空数据不处理
  107. if new_market_info.len() == 0 {
  108. return;
  109. }
  110. if self.loop_count < i64::MAX {
  111. self.loop_count += 1;
  112. }
  113. self.market_info_list.push(new_market_info.clone());
  114. (*self).processor();
  115. (*self).check_length();
  116. }
  117. // 获取预定价格, 也就是python的Get_ref函数
  118. #[instrument(skip(self, ref_ticker_map), level="TRACE")]
  119. pub fn get_ref_price(&mut self, ref_ticker_map: &BTreeMap<String, Ticker>) -> Vec<Vec<Decimal>> {
  120. let mut ref_price_list = vec![];
  121. let ref_exchange_names: Vec<_> = ref_ticker_map.keys().collect();
  122. for ref_index in 0..ref_exchange_names.len() {
  123. let ref_exchange = ref_exchange_names[ref_index];
  124. let ticker = ref_ticker_map.get(ref_exchange).unwrap();
  125. let bid_price = ticker.buy;
  126. let ask_price = ticker.sell;
  127. let ref_bid_price = bid_price * self.alpha[ref_index] - self.avg_spread_list[ref_index];
  128. let ref_ask_price = ask_price * self.alpha[ref_index] - self.avg_spread_list[ref_index];
  129. ref_price_list.push(vec![ref_bid_price, ref_ask_price]);
  130. }
  131. return ref_price_list;
  132. }
  133. }
  134. #[cfg(test)]
  135. mod tests {
  136. use std::collections::BTreeMap;
  137. use std::io;
  138. use std::io::Write;
  139. use rust_decimal_macros::dec;
  140. use standard::Ticker;
  141. use crate::predictor::Predictor;
  142. #[test]
  143. fn predictor_build_test() {
  144. let mut stdout = io::stdout();
  145. let predictor1 = Predictor::new(2)
  146. .alpha(vec![dec!(0.99); 100])
  147. .gamma(dec!(0.8));
  148. writeln!(stdout, "predictor1:").unwrap();
  149. writeln!(stdout, "{:?}", predictor1).unwrap();
  150. writeln!(stdout, "").unwrap();
  151. let predictor2 = Predictor::new(2);
  152. writeln!(stdout, "predictor2:").unwrap();
  153. writeln!(stdout, "{:?}", predictor2).unwrap();
  154. writeln!(stdout, "").unwrap();
  155. }
  156. #[test]
  157. fn market_info_handler_test() {
  158. let mut predictor = Predictor::new(1);
  159. let market_info_0 = vec![dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.89), dec!(0.79), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79)];
  160. predictor.market_info_handler(&market_info_0);
  161. let market_info_1 = vec![dec!(0.98), dec!(0.99), dec!(0.56), dec!(0.49), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.89), dec!(0.79), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79)];
  162. predictor.market_info_handler(&market_info_1);
  163. }
  164. #[test]
  165. fn get_ref_price_test() {
  166. let mut predictor = Predictor::new(1)
  167. .alpha(vec![dec!(0.99); 100])
  168. .gamma(dec!(0.8));
  169. //
  170. let mut ref_ticker_map: BTreeMap<String, Ticker> = BTreeMap::new();
  171. ref_ticker_map.insert("binance".to_string(), Ticker{
  172. time: 0,
  173. high: Default::default(),
  174. low: Default::default(),
  175. sell: dec!(0.93),
  176. buy: dec!(0.92),
  177. last: Default::default(),
  178. volume: Default::default(),
  179. });
  180. println!("before market info: {:?}", predictor.get_ref_price(&ref_ticker_map));
  181. let mut market_info = vec![];
  182. market_info = vec![dec!(0.99), dec!(1.0), dec!(0.991), dec!(0.79), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.89), dec!(0.79), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79)];
  183. predictor.market_info_handler(&market_info);
  184. println!("market info 0: {:?}", predictor.get_ref_price(&ref_ticker_map));
  185. market_info = vec![dec!(0.98), dec!(0.99), dec!(0.981), dec!(0.49), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.89)];
  186. predictor.market_info_handler(&market_info);
  187. println!("market info 1: {:?}", predictor.get_ref_price(&ref_ticker_map));
  188. market_info = vec![dec!(0.93), dec!(1.0), dec!(0.931), dec!(0.79), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.89)];
  189. predictor.market_info_handler(&market_info);
  190. println!("market info 2: {:?}", predictor.get_ref_price(&ref_ticker_map));
  191. market_info = vec![dec!(0.98), dec!(0.49), dec!(0.981), dec!(0.49), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.89)];
  192. predictor.market_info_handler(&market_info);
  193. println!("market info 3: {:?}", predictor.get_ref_price(&ref_ticker_map));
  194. market_info = vec![dec!(0.99), dec!(1.0), dec!(0.991), dec!(0.69), dec!(0.99), dec!(1.0), dec!(0.89), dec!(0.79), dec!(0.89)];
  195. predictor.market_info_handler(&market_info);
  196. println!("market info 4: {:?}", predictor.get_ref_price(&ref_ticker_map));
  197. market_info = vec![dec!(0.98), dec!(0.969), dec!(0.981), dec!(0.49), dec!(0.99), dec!(1.0), dec!(1.0), dec!(1.0), dec!(0.89)];
  198. predictor.market_info_handler(&market_info);
  199. println!("market info 5: {:?}", predictor.get_ref_price(&ref_ticker_map));
  200. }
  201. }