Переглянути джерело

指标能生成了,允许跨域。

skyffire 1 рік тому
батько
коміт
372e075eba
4 змінених файлів з 141 додано та 1 видалено
  1. 1 0
      Cargo.toml
  2. 1 0
      src/main.rs
  3. 123 0
      src/msv_generate.rs
  4. 16 1
      src/server.rs

+ 1 - 0
Cargo.toml

@@ -26,5 +26,6 @@ ctrlc = "3.2.5"
 
 actix-rt = "2.5.0"
 actix-web = "4.0.0-beta.12"
+actix-cors = "0.6"
 
 reqwest = { version = "0.11", features = ["json"] }

+ 1 - 0
src/main.rs

@@ -1,6 +1,7 @@
 mod control_c;
 mod server;
 pub mod db_connector;
+mod msv_generate;
 
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};

+ 123 - 0
src/msv_generate.rs

@@ -0,0 +1,123 @@
+use std::cmp::max;
+use std::collections::BTreeMap;
+use std::str::FromStr;
+use rust_decimal::Decimal;
+use rust_decimal_macros::dec;
+use serde_json::{json, Value};
+use crate::server::Trade;
+
+// 将trades_json转换为指标
+pub fn generate_msv(trades_json: Value) -> Value {
+    let trades = parse_json_to_trades(trades_json);
+
+    generate_msv_by_trades(trades)
+}
+
+// 将trades转换为具体指标
+pub fn generate_msv_by_trades(mut trades: Vec<Trade>) -> Value {
+    let mut amplitude_map: BTreeMap<Decimal, Decimal> = BTreeMap::new();
+
+    // 每一个元素都遍历一遍
+    trades.reverse();
+    for (index, trade) in trades.iter().enumerate() {
+        // 该元素向前遍历range毫秒
+        let mut range_index = if index == 0 {
+            0
+        } else {
+            index
+        };
+
+        // 寻找区间最大值、最小值
+        let mut max_price = dec!(-1);
+        let mut min_price = dec!(1e28);
+
+        loop {
+            // 第0个就不搞
+            if range_index == 0 {
+                break;
+            }
+
+            let flag_trade = trades.get(range_index).unwrap();
+            let range_time = trade.time - flag_trade.time;
+            // 判断该ticker是否是range ms以外
+            if range_time > dec!(50) {
+                break;
+            }
+
+            // 判断最大值、最小值
+            if flag_trade.price > max_price {
+                max_price = flag_trade.price;
+            }
+            if flag_trade.price < min_price {
+                min_price = flag_trade.price;
+            }
+
+            range_index -= 1;
+        }
+
+        // 逻辑计算层
+        // 取离当前点最远的点进行测量
+        let last_price = trade.price;
+
+        // 不是初始值,以及不是0波动
+        if index != 0 {
+            let mut up_rate = Decimal::ONE_HUNDRED * (last_price - min_price) / min_price;
+            let mut dn_rate = Decimal::ONE_HUNDRED * (last_price - max_price) / max_price;
+
+            up_rate.rescale(2);
+            dn_rate.rescale(2);
+
+            // 去除小数位之后,可以忽略一些太小的波动,减少图表生成压力
+            if up_rate.eq(&Decimal::ZERO) && dn_rate.eq(&Decimal::ZERO) {
+                continue
+            }
+
+            // 如果已经生成了一个波动,则也要和已生成的波动进行比较
+            let insert_value = if amplitude_map.contains_key(&trade.time) {
+                let origin_rate = amplitude_map.get(&trade.time).unwrap();
+
+                if up_rate > dn_rate.abs() {
+                    max(*origin_rate, up_rate)
+                } else {
+                    max(*origin_rate, dn_rate)
+                }
+            } else {
+                if up_rate > dn_rate.abs() {
+                    up_rate
+                } else {
+                    dn_rate
+                }
+            };
+
+            amplitude_map.insert(trade.time, insert_value);
+        }
+    }
+
+    let x: Vec<Decimal> = amplitude_map.keys().cloned().collect();
+    let y: Vec<Decimal> = amplitude_map.values().cloned().collect();
+    let total_size = trades.len();
+    let result_size = x.len();
+    json!({
+        "x": x,
+        "y": y,
+        "total_size": total_size,
+        "result_size": result_size,
+    })
+}
+
+// 将json转换为trades
+pub fn parse_json_to_trades(trades_json: Value) -> Vec<Trade> {
+    let mut rst = vec![];
+
+    for trade_json in trades_json.as_array().unwrap() {
+        let arr = trade_json.as_array().unwrap();
+        rst.push(Trade {
+            id: arr[0].as_str().unwrap().to_string(),
+            time: Decimal::from_str(arr[1].as_str().unwrap()).unwrap(),
+            size: Decimal::from_str(arr[2].as_str().unwrap()).unwrap(),
+            price: Decimal::from_str(arr[3].as_str().unwrap()).unwrap(),
+        });
+    }
+
+    rst
+}

+ 16 - 1
src/server.rs

@@ -1,5 +1,6 @@
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
+use actix_cors::Cors;
 use actix_web::{web, App, HttpResponse, HttpServer, Responder, get};
 use chrono::Utc;
 use rust_decimal::Decimal;
@@ -7,6 +8,7 @@ use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use tracing::{info};
 use crate::db_connector::get_trades_json;
+use crate::msv_generate::generate_msv;
 
 // 定义用于反序列化查询参数的结构体
 #[derive(Serialize, Deserialize, Clone)]
@@ -61,6 +63,7 @@ async fn get_symbols_by_filter() -> impl Responder {
 
 #[get("/get_indicator")]
 async fn get_indicator(query: web::Query<IndicatorQuery>) -> impl Responder {
+    // 客户端传过来的数据校验
     if query.validate() {
         // 链接数据服务器查询数据
         let end_time = Utc::now().timestamp_millis();
@@ -74,11 +77,15 @@ async fn get_indicator(query: web::Query<IndicatorQuery>) -> impl Responder {
 
         // 对数据库返回的数据进行容错处理
         if db_response.code == 200 {
+            // 指标生成
+            let indicator = generate_msv(db_response.data);
+
+            // 返回数据
             let response = Response {
                 query_string: serde_json::to_value(&query.into_inner()).unwrap(),
                 message: Some("指标生成完毕".to_string()),
                 code: 200,
-                data: db_response.data,
+                data: indicator,
             };
 
             let json_string = serde_json::to_string(&response).unwrap();
@@ -106,7 +113,15 @@ pub fn run_server(port: u32, running: Arc<AtomicBool>) {
 
     // 启动server
     let server_fut = HttpServer::new(move || {
+        // 配置 CORS
+        let cors = Cors::permissive()
+            .allow_any_origin()
+            .allow_any_header()
+            .allow_any_method()
+            .max_age(3600); // 设置预检请求的缓存时间
+
         App::new()
+            .wrap(cors)
             .service(get_symbols_by_filter)
             .service(get_indicator)
     })