ソースを参照

预测器的基本部分都弄好了

skyfffire 2 年 前
コミット
3d661f59e8
1 ファイル変更158 行追加0 行削除
  1. 158 0
      strategy/src/predictor.rs

+ 158 - 0
strategy/src/predictor.rs

@@ -0,0 +1,158 @@
+use rust_decimal::prelude::*;
+use rust_decimal_macros::dec;
+use standard::Ticker;
+use crate::utils;
+
+#[derive(Debug)]
+pub struct Predictor {
+    pub loop_count: i64,                                                // 统计当前预测器更新的次数,loop
+    pub market_info_list: Vec<Vec<Decimal>>,                            // TODO 这里存放的是一个市场数据汇总,后面有时间再优化,arr
+    pub mid_price_list: Vec<Decimal>,                                   // 中间价的数组,trade_mp_series
+    pub ref_mid_price_per_exchange_per_frame: Vec<Vec<Decimal>>,        // 参考交易所的中间价的数组(注意是二维的,因为参考交易所有多个),ref_mp_series
+    pub ref_exchange_length: usize,                                     // 参考交易所数量,ref_num
+    pub data_length_max: usize,                                         // 各类数据极限长度,防止爆内存,window
+    pub alpha: Vec<Decimal>,                                            // 价格系数
+    pub gamma: Decimal,                                                 // 定价系数
+    pub avg_spread_list: Vec<Decimal>,                                  // 平均价差
+}
+
+/*
+    使用Builder设计模式创建价格预测器,可以有效提高代码整洁度
+    下面的单元测试有使用示例
+*/
+impl Predictor {
+    pub fn new(ref_exchange_length: usize) -> Self {
+        Self {
+            loop_count: 0,
+            market_info_list: vec![],
+            mid_price_list: vec![],
+            ref_mid_price_per_exchange_per_frame: vec![],
+            ref_exchange_length,
+            data_length_max: 10,
+            alpha: vec![Decimal::new(1, 0); 100],
+            gamma: Decimal::from_f64(0.999).unwrap(),
+            avg_spread_list: vec![dec!(0); ref_exchange_length],
+        }
+    }
+
+    pub fn alpha(mut self, alpha: Vec<Decimal>) -> Self {
+        self.alpha = alpha;
+        self
+    }
+
+    pub fn gamma(mut self, gamma: Decimal) -> Self {
+        self.gamma = gamma;
+        self
+    }
+
+    // 计算任务,python里写作processer,是个错误的单词
+    pub fn processor(mut self) {
+        let last_market_info = self.market_info_list.last().unwrap();
+
+        // 更新mid_price
+        let bid_price = last_market_info[utils::BID_PRICE_INDEX];
+        let ask_price = last_market_info[utils::BID_PRICE_INDEX];
+        let mid_price = (bid_price + ask_price) * dec!(0.5);
+        self.mid_price_list.push(mid_price);
+
+        // 更新参考ref_mid_price
+        let mut ref_mid_price_per_exchange = vec![];
+        for ref_index in 0..self.ref_exchange_length {
+            let ref_bid_price = last_market_info[utils::LENGTH*(1+ref_index)+utils::BID_PRICE_INDEX];
+            let ref_ask_price = last_market_info[utils::LENGTH*(1+ref_index)+utils::ASK_PRICE_INDEX];
+            let ref_mid_price = (ref_bid_price + ref_ask_price) * dec!(0.5);
+            // 依照交易所次序添加到ref_mid_price_per_exchange中
+            ref_mid_price_per_exchange.push(ref_mid_price);
+        }
+        self.ref_mid_price_per_exchange_per_frame.push(ref_mid_price_per_exchange);
+
+        // 价差更新
+        self.update_avg_spread()
+    }
+
+    // 更新平均价差,_update_avg_spread
+    fn update_avg_spread(mut self) {
+        let last_ref_mid_price_per_exchange = self.ref_mid_price_per_exchange_per_frame.last().unwrap();
+        let mid_price_last = self.mid_price_list.last().unwrap();
+
+        for ref_index in 0..self.ref_exchange_length {
+            let bias = last_ref_mid_price_per_exchange[ref_index] * self.alpha[ref_index] - mid_price_last;
+
+            let mut gamma = self.gamma;
+            // 如果程序刚刚启动,gamma值不能太大
+            if self.loop_count < 100 {
+                gamma = dec!(0.9);
+            }
+
+            // 检测是否初始化
+            if dec!(0).eq(&self.avg_spread_list[ref_index]) {
+                self.avg_spread_list[ref_index] = bias;
+            } else {
+                self.avg_spread_list[ref_index] = self.avg_spread_list[ref_index] * gamma + bias*(dec!(1)-gamma);
+            }
+        }
+    }
+
+    // 长度限定
+    fn check_length(mut self) {
+        // 市场汇总信息长度限定
+        if self.market_info_list.len() > self.data_length_max {
+            self.market_info_list.remove(0);
+        }
+        // 交易交易所的mid_price长度限定
+        if self.mid_price_list.len() > self.data_length_max {
+            self.mid_price_list.remove(0);
+        }
+        // 参考交易所的长度限定
+        if self.ref_mid_price_per_exchange_per_frame.len() > self.data_length_max {
+            self.ref_mid_price_per_exchange_per_frame.remove(0);
+        }
+    }
+
+    // 市场信息处理器,也是python里的onTime方法
+    fn market_info_handler(mut self, new_market_info: Vec<Decimal>) {
+        // 异常行情信息不处理
+        if new_market_info == None {
+            return;
+        }
+
+        // 空数据不处理
+        if new_market_info.len() == 0 {
+            return;
+        }
+
+        self.loop_count += 1;
+        self.market_info_list.push(new_market_info);
+        self.processor();
+        self.check_length();
+    }
+
+    // fn get_ref_price(mut self, ref_ticker_list: Vec<Ticker>) {
+    //
+    // }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::io;
+    use std::io::Write;
+    use rust_decimal::Decimal;
+    use crate::predictor::Predictor;
+
+    #[test]
+    fn predictor_build_test() {
+        let mut stdout = io::stdout();
+
+        let predictor1 = Predictor::new(10)
+            .alpha(vec![Decimal::new(99, 2); 100])
+            .gamma(Decimal::new(8, 1));
+        writeln!(stdout, "predictor1:").unwrap();
+        writeln!(stdout, "{:?}", predictor1).unwrap();
+        writeln!(stdout, "").unwrap();
+
+        let predictor2 = Predictor::new(10);
+        writeln!(stdout, "predictor2:").unwrap();
+        writeln!(stdout, "{:?}", predictor2).unwrap();
+        writeln!(stdout, "").unwrap();
+    }
+}