浏览代码

准备获取k线数据了

skyffire 1 年之前
父节点
当前提交
e15d09f51a
共有 6 个文件被更改,包括 290 次插入12 次删除
  1. 9 2
      README.MD
  2. 74 3
      src/db_connector.rs
  3. 1 0
      src/main.rs
  4. 20 0
      src/params_utils.rs
  5. 43 7
      src/server.rs
  6. 143 0
      src/symbol_filter.rs

+ 9 - 2
README.MD

@@ -97,8 +97,15 @@
     "data": [
         {
             "symbol": "BTC_USDT",           // 交易对
-            "rise": 3,                      // 涨跌幅,3代表3%
-            "volume": 0.1,                  // 交易量,单位是M(百万)
+            "rise": {                       // 涨跌幅,3代表3%
+                "gate_usdt_swap": 3,
+                "bitget_usdt_swap": 2.9,
+            },
+            "volume": {                     // 交易量,单位是M(百万)
+                "gate_usdt_swap": 0.1,
+                "bitget_usdt_swap": 0.5,
+                "total": 0.6,               // 所有交易量之和
+            },
         },
         ...
     ],

+ 74 - 3
src/db_connector.rs

@@ -36,10 +36,71 @@ pub async fn get_trades_json(exchange: &str, symbol: &str, start_at: i64, end_at
         serde_json::from_str(response_text.as_str()).unwrap()
     } else {
         Response {
-            msg: Some("请求失败,预计是指标层的网络请求错误。".to_string()),
-            query: Default::default(),
+            msg: Some("get_trades_json 请求失败,预计是指标层的网络请求错误。".to_string()),
+            query: params,
             data: Default::default(),
-            code: -200,
+            code: 500,
+        }
+    }
+}
+
+pub async fn get_records_json(exchange: &str, symbol: &str, start_at: i64, end_at: i64) -> Response {
+    let url = "http://dc.skyfffire.com:8888/records";
+    let params = json!({
+        "exchange": exchange,
+        "symbol": symbol,
+        "start_time": start_at,
+        "end_time": end_at
+    });
+
+    // 创建 HTTP 客户端
+    let client = Client::new();
+
+    // 发送 GET 请求
+    let response = client.get(url)
+        .query(&params)
+        .send()
+        .await.unwrap();
+
+    // 错误处理
+    if response.status().is_success() {
+        let response_text = response.text().await.unwrap();
+        serde_json::from_str(response_text.as_str()).unwrap()
+    } else {
+        Response {
+            msg: Some("get_records_json 请求失败,预计是指标层的网络请求错误。".to_string()),
+            query: params,
+            data: Default::default(),
+            code: 500,
+        }
+    }
+}
+
+pub async fn get_symbols_json(exchange: &str) -> Response {
+    let url = "http://dc.skyfffire.com:8888/symbols";
+    let params = json!({
+        "exchange": exchange,
+    });
+
+    // 创建 HTTP 客户端
+    let client = Client::new();
+
+    // 发送 GET 请求
+    let response = client.get(url)
+        .query(&params)
+        .send()
+        .await.unwrap();
+
+    // 错误处理
+    if response.status().is_success() {
+        let response_text = response.text().await.unwrap();
+        serde_json::from_str(response_text.as_str()).unwrap()
+    } else {
+        Response {
+            msg: Some("get_symbols_json 请求失败,预计是指标层的网络请求错误。".to_string()),
+            query: params.clone(),
+            data: Default::default(),
+            code: 500,
         }
     }
 }
@@ -53,3 +114,13 @@ async fn get_trades_test() {
     let rst = get_trades_json("bitget_usdt_swap", "BTC_USDT", 1713210360000, 1713210960000).await;
     info!(?rst)
 }
+
+#[tokio::test]
+async fn get_symbols_test() {
+    use global::log_utils::init_log_with_info;
+    use tracing::info;
+    init_log_with_info();
+
+    let rst = get_symbols_json("bitget_usdt_swap").await;
+    info!(?rst)
+}

+ 1 - 0
src/main.rs

@@ -4,6 +4,7 @@ pub mod db_connector;
 mod msv;
 mod trades;
 mod params_utils;
+mod symbol_filter;
 
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};

+ 20 - 0
src/params_utils.rs

@@ -64,6 +64,26 @@ pub fn get_str(value: Value, param_name: &str) -> Result<String, HttpResponse> {
     }
 }
 
+pub fn get_array(value: Value, param_name: &str) -> Result<Vec<Value>, HttpResponse> {
+    match value[param_name].as_array() {
+        None => {
+            // 返回数据
+            let response = Response {
+                query: value.clone(),
+                msg: Some(format!("{}是【数组】必填项, 你的参数: {}", param_name, value.to_string())),
+                code: 500,
+                data: Value::Null,
+            };
+
+            let json_string = serde_json::to_string(&response).unwrap();
+            Err(HttpResponse::Ok().content_type("application/json").body(json_string))
+        }
+        Some(array) => {
+            Ok(array.clone())
+        }
+    }
+}
+
 #[tokio::test]
 async fn test_decimal_to_i64() {
     use global::log_utils::init_log_with_info;

+ 43 - 7
src/server.rs

@@ -3,10 +3,13 @@ use std::sync::atomic::{AtomicBool, Ordering};
 use actix_cors::Cors;
 use actix_web::{web, App, HttpResponse, HttpServer, Responder, post};
 use rust_decimal::Decimal;
+use rust_decimal_macros::dec;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use tracing::{info};
 use crate::msv::generate_msv;
+use crate::params_utils::{get_array, get_str, parse_str_to_decimal};
+use crate::symbol_filter::get_symbols;
 use crate::trades::generate_trades;
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -28,15 +31,48 @@ pub struct Trade {
 // 句柄 GET 请求
 #[post("/ia/get_symbols_by_filter")]
 async fn get_symbols_by_filter(query: web::Json<Value>) -> impl Responder {
-    let response = Response {
-        query: query.into_inner(),
-        msg: Some("get_symbols_by_filter 这个接口还没做".to_string()),
-        code: 400,
-        data: Value::Null,
+    let query_value = query.clone();
+
+    // 参数处理
+    let mode = match get_str(query_value.clone(), "mode") {
+        Ok(str) => {
+            str
+        }
+        Err(response) => {
+            return response
+        }
+    };
+    let exchanges = match get_array(query_value.clone(), "exchanges") {
+        Ok(rst) => {
+            let mut exchanges = vec![];
+            for exchange_value in rst {
+                exchanges.push(exchange_value.as_str().unwrap().to_string())
+            }
+
+            exchanges
+        }
+        Err(response) => {
+            return response
+        }
+    };
+    let minute_time_range = match parse_str_to_decimal(query_value.clone(), "minute_time_range") {
+        Ok(range) => {
+            range
+        }
+        Err(response) => {
+            return response
+        }
+    };
+    let filters = match get_array(query_value.clone(), "filters") {
+        Ok(filters) => {
+            filters
+        }
+        Err(response) => {
+            vec![]
+        }
     };
 
-    let json_string = serde_json::to_string(&response).unwrap();
-    HttpResponse::BadRequest().content_type("application/json").body(json_string)
+    get_symbols(&mode, &exchanges, &minute_time_range, &filters).await
 }
 
 // ia: intelligence agency, 情报部门

+ 143 - 0
src/symbol_filter.rs

@@ -0,0 +1,143 @@
+use std::collections::{HashMap, HashSet};
+use actix_web::HttpResponse;
+use rust_decimal::Decimal;
+use serde_json::{json, Value};
+use tracing::info;
+use crate::db_connector::{get_records_json, get_symbols_json};
+use crate::msv::{generate_msv_by_trades, parse_json_to_trades};
+use crate::server::Response;
+
+fn get_public_symbols(symbols_map: &Value) -> Vec<String> {
+    // 解析 JSON 数据为 HashMap
+    let exchanges_map: HashMap<String, Vec<String>> = serde_json::from_value(symbols_map.clone()).unwrap();
+
+    // 初始化一个可选的 HashSet 用于存储交集结果
+    let mut common_symbols: Option<HashSet<String>> = None;
+
+    // 遍历每个交易所的 symbols
+    for symbols in exchanges_map.values() {
+        let current_set: HashSet<String> = symbols.into_iter().cloned().collect();
+        common_symbols = match common_symbols {
+            Some(set) => Some(set.intersection(&current_set).cloned().collect()),
+            None => Some(current_set),
+        };
+    }
+
+    // 将结果从 HashSet 转换为 Vec<String>
+    common_symbols.unwrap_or_default().into_iter().collect::<Vec<String>>()
+}
+
+pub async fn get_symbols(mode: &str, exchanges: &Vec<String>, minute_time_range: &Decimal, filters: &Vec<Value>) -> HttpResponse {
+    // 1. 获取所选交易所的所有交易对
+    let mut symbols_map = json!({});
+    for exchange in exchanges {
+        let db_response = get_symbols_json(exchange.as_str()).await;
+
+        // 对数据库返回的数据进行容错处理
+        if db_response.code == 200 {
+            let symbol_array_value = db_response.data;
+
+            symbols_map[exchange] = symbol_array_value.clone();
+        } else {
+            let json_string = serde_json::to_string(&db_response).unwrap();
+            return HttpResponse::Ok().content_type("application/json").body(json_string);
+        }
+    }
+    // 2. 如果是多交易所,则获取所有交易所的交集
+    let mut symbols: Vec<String> = get_public_symbols(&symbols_map);
+    // 确保有足够的元素
+    let n = 20;
+    // 获取前20个元素,如果不足20个,则获取全部
+    let top_symbols = if symbols.len() > n {
+        &symbols[0..n]
+    } else {
+        &symbols[..]
+    };
+    // 如果你需要将这些元素放在新的 Vec 中
+    symbols = top_symbols.to_vec();
+
+    // 3. 获取它们的k线数据,注意时间范围以及保存形式。
+    // {
+    //     "symbol1": {
+    //     "exchange1": [k1,k2,...]
+    //     "exchange2": [k1,k2,...]
+    // },
+    //     "symbol2": {
+    //     "exchange1": [k1,k2,...]
+    //     "exchange2": [k1,k2,...]
+    // },
+    // }
+    let mut records_map = json!({});
+    for symbol in symbols {
+        records_map[symbol.clone()] = json!({});
+        for exchange in exchanges {
+            let db_response = get_records_json(exchange.as_str()).await;
+
+            // 对数据库返回的数据进行容错处理
+            if db_response.code == 200 {
+                let symbol_array_value = db_response.data;
+
+                records_map[symbol.clone()][exchange] = symbol_array_value.clone();
+            } else {
+                let json_string = serde_json::to_string(&db_response).unwrap();
+                return HttpResponse::Ok().content_type("application/json").body(json_string);
+            }
+        }
+    }
+    // 4. 整理完之后,获取指标数据,进行逻辑判断,注意与/或只是指过滤器,如果选择了多交易所,交易所的逻辑部分都是与
+    // 举例:
+    //      选择了gate_usdt_swap与bitget_usdt_swap两个交易所,并且对交易量进行了0.5M的过滤
+    //      两个交易所都有xrp,如果gate与bitget的xrp交易量都大于0.5M,则会显示在最终列表中
+    // 5. 最终出来的数据结构,要体现出交易所
+
+    let response_value = json!([
+        {
+            "symbol": "BTC_USDT",
+            "rise": {
+                "gate_usdt_swap": 3,
+                "bitget_usdt_swap": 2.9,
+            },
+            "volume": {
+                "gate_usdt_swap": 0.1,
+                "bitget_usdt_swap": 0.5,
+                "total": 0.6,
+            },
+        },
+        {
+            "symbol": "DOGE_USDT",
+            "rise": {
+                "gate_usdt_swap": -1.9,
+                "bitget_usdt_swap": -2,
+            },
+            "volume": {
+                "gate_usdt_swap": 0.2,
+                "bitget_usdt_swap": 0.3,
+                "total": 0.5,
+            },
+        }
+    ]);
+    let response = Response {
+        query: Value::Null,
+        msg: Some("指标生成完毕".to_string()),
+        code: 200,
+        data: response_value,
+    };
+
+    let json_string = serde_json::to_string(&response).unwrap();
+    HttpResponse::Ok().content_type("application/json").body(json_string)
+}
+
+#[tokio::test]
+async fn get_public_symbols_test() {
+    use global::log_utils::init_log_with_info;
+    use tracing::info;
+    init_log_with_info();
+
+    let exchanges = json!({
+        "bitget_usdt_swap": ["BTC_USDT", "ETH_USDT", "MEW_USDT"],
+        "gate_usdt_swap": ["BTC_USDT", "ETH_USDT", "DOGE_USDT"]
+    });
+
+    let rst = get_public_symbols(&exchanges);
+    info!(?rst)
+}