Selaa lähdekoodia

消息分发 之 best prices

skyfffire 2 viikkoa sitten
vanhempi
commit
cc88cf9ac8
3 muutettua tiedostoa jossa 93 lisäystä ja 50 poistoa
  1. 72 33
      src/data_manager.rs
  2. 12 8
      src/exchange/extended_stream_client.rs
  3. 9 9
      src/main.rs

+ 72 - 33
src/data_manager.rs

@@ -1,12 +1,15 @@
 use std::collections::{HashMap};
+use std::str::FromStr;
 use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
-use anyhow::{Result};
-use tracing::{warn};
+use anyhow::{anyhow, bail, Result};
+use rust_decimal::Decimal;
+use serde_json::Value;
+use tracing::{info, warn};
 use crate::utils::response::Response;
 
 pub struct DataManager {
-    // pub asks_map: HashMap<String, BTreeMap<Decimal, Decimal>>,
-    // pub bids_map: HashMap<String, BTreeMap<Reverse<Decimal>, Decimal>>,
+    pub best_ask: Decimal,
+    pub best_bid: Decimal,
 
     pub delay_total: AtomicI64,
     pub delay_count: AtomicU64,
@@ -14,12 +17,10 @@ pub struct DataManager {
 
 impl DataManager {
     pub fn new() -> Self {
-        // let asks_map: HashMap<String, BTreeMap<Decimal, Decimal>> = HashMap::new();
-        // let bids_map: HashMap<String, BTreeMap<Reverse<Decimal>, Decimal>> = HashMap::new();
-
         DataManager {
-            // asks_map,
-            // bids_map,
+            best_ask: Default::default(),
+            best_bid: Default::default(),
+            
             delay_total: AtomicI64::new(0),
             delay_count: AtomicU64::new(0),
         }
@@ -47,32 +48,70 @@ impl DataManager {
         self.delay_count.store(0, Ordering::Relaxed); // 原子写
     }
     
-    pub async fn dispatch_message(&mut self, _response: &Response) -> Result<()> {
-        // // 1. 预解析为通用的 Value
-        // let v = response.data.clone();
-        // 
-        // info!("准备分发的消息:{}", serde_json::to_string_pretty(&v)?);
+    pub async fn dispatch_message(&mut self, response: &Response) -> Result<()> {
+        // 预解析为通用的 Value
+        let v = response.data.clone();
+        
+        // info!("准备分发的消息:{}, {}", serde_json::to_string_pretty(&v)?, response.label);
+
+        // 2. 获取 topic_info 字段用于路由消息,在该策略中extended可以用label
+        let topic_info = &response.label;
+        
+        // 3. 根据 topic_info 的内容进行分发 (match)
+        if topic_info.contains("ExtendedBestPrices") {
+            self.process_best_prices(&v).await?;            
+        } else if topic_info.contains("spot@public.aggre.depth.v3.api.pb") {
+            
+        } else {
+            // 如果是未知的 topic,返回一个错误
+            bail!("Received a message with an unknown topic_info: {}", topic_info);
+        }
+
+        Ok(())
+    }
+    
+    pub async fn process_best_prices(&mut self, value: &Value) -> Result<()> {
+        // 预先捕获整个 Value 的字符串表示,用于错误报告
+        let value_str = serde_json::to_string(&value).unwrap_or_else(|_| "无法序列化 JSON Value".to_string());
+
+        // 尝试获取 data 字段
+        let data = value.get("data")
+            .ok_or_else(|| anyhow!("获取 'data' 字段失败,原始 JSON: {}", value_str))?;
 
-        // 2. 获取 topic_info 字段用于路由
-        // and_then 确保了 get 返回 Some 时才调用 as_str
-        // context 在任何一步失败时提供错误信息 (字段不存在,或不是字符串)
-        // let topic_info = v
-        //     .get("topic_info")
-        //     .and_then(Value::as_str)
-        //     .context("Message is missing 'topic_info' field or it's not a string")?;
-        //
-        // // 3. 根据 topic_info 的内容进行分发 (match)
-        // if topic_info.contains("spot@public.kline.v3.api.pb") {
-        //     // 如果是K线数据,调用 process_kline
-        //     self.process_klines(&v).await?;
-        // } else if topic_info.contains("spot@public.aggre.depth.v3.api.pb") {
-        //     // 如果是增量深度数据,调用 process_depth_update
-        //     self.process_depth_update(&v).await?;
-        // } else {
-        //     // 如果是未知的 topic,返回一个错误
-        //     bail!("Received a message with an unknown topic_info: {}", topic_info);
-        // }
+        // 尝试从 data 中获取 "a" (asks) 数组
+        let asks_array = data.get("a")
+            .and_then(|v| v.as_array()) // and_then 链式调用,确保只有当 v 存在且是数组时才继续
+            .ok_or_else(|| anyhow!("获取 'data.a' 数组失败,原始 JSON: {}", value_str))?;
 
+        // 尝试从 data 中获取 "b" (bids) 数组
+        let bids_array = data.get("b")
+            .and_then(|v| v.as_array())
+            .ok_or_else(|| anyhow!("获取 'data.b' 数组失败,原始 JSON: {}", value_str))?;
+
+        // 如若有发送asks信息
+        if asks_array.len() > 0 {
+            let ask_item = &asks_array[0];
+            let p = ask_item.get("p")
+                .and_then(|v| v.as_str())
+                .ok_or_else(|| anyhow!("获取 'data.a.p' 字符串失败,原始 JSON: {}", value_str))?;
+            
+            self.best_ask = Decimal::from_str(p)
+                .map_err(|e| anyhow!("将价格字符串 '{}' 解析为 Decimal 失败: {},原始 JSON: {}", p, e, value_str))?;
+        }
+
+        // 如若有发送bids信息
+        if bids_array.len() > 0 {
+            let bid_item = &bids_array[0];
+            let p = bid_item.get("p")
+                .and_then(|v| v.as_str())
+                .ok_or_else(|| anyhow!("获取 'data.b.p' 字符串失败,原始 JSON: {}", value_str))?;
+
+            self.best_bid = Decimal::from_str(p)
+                .map_err(|e| anyhow!("将价格字符串 '{}' 解析为 Decimal 失败: {},原始 JSON: {}", p, e, value_str))?;
+        }
+        
+        info!("{}, {}", self.best_ask, self.best_bid);
+        
         Ok(())
     }
 }

+ 12 - 8
src/exchange/extended_stream_client.rs

@@ -18,8 +18,8 @@ use crate::utils::stream_utils::{StreamUtils, HeartbeatType};
 #[derive(Clone)]
 #[allow(dead_code)]
 pub struct ExtendedStreamClient {
-    // 标签
-    tag: String,
+    // 数据标签
+    label: String,
     // 地址
     address_url: String,
     // 账号
@@ -30,7 +30,7 @@ pub struct ExtendedStreamClient {
 
 impl ExtendedStreamClient {
     // ============================================= 构造函数 ================================================
-    fn new(tag: String, account_option: Option<ExtendedAccount>, subscribe_pattern: String, is_testnet: bool) -> ExtendedStreamClient {
+    fn new(label: String, account_option: Option<ExtendedAccount>, subscribe_pattern: String, is_testnet: bool) -> ExtendedStreamClient {
         let host = match is_testnet {
             true => {
                 "wss://api.starknet.sepolia.extended.exchange/stream.extended.exchange/v1/".to_string()  // testnet
@@ -43,7 +43,7 @@ impl ExtendedStreamClient {
         let address_url = format!("{}{}", host, subscribe_pattern);
 
         ExtendedStreamClient {
-            tag,
+            label,
             address_url,
             account_option,
             heartbeat_time: 1000 * 10,
@@ -51,8 +51,12 @@ impl ExtendedStreamClient {
     }
 
     // ============================================= 订阅函数 ================================================
-    pub fn order_books(tag: String, account_option: Option<ExtendedAccount>, symbol: String, is_testnet: bool) -> ExtendedStreamClient {
-        Self::new(tag, account_option, format!("orderbooks/{}", symbol), is_testnet)
+    pub fn order_books(label: String, account_option: Option<ExtendedAccount>, symbol: String, is_testnet: bool) -> ExtendedStreamClient {
+        Self::new(label, account_option, format!("orderbooks/{}", symbol), is_testnet)
+    }
+    
+    pub fn best_prices(label: String, account_option: Option<ExtendedAccount>, symbol: String, is_testnet: bool) -> ExtendedStreamClient {
+        Self::new(label, account_option, format!("orderbooks/{}?depth=1", symbol), is_testnet)
     }
 
     // 链接
@@ -66,7 +70,7 @@ impl ExtendedStreamClient {
             Future: std::future::Future<Output=()> + Send + 'static, // 确保 Fut 是一个 Future,且输出类型为 ()
     {
         let address_url = self.address_url.clone();
-        let tag = self.tag.clone();
+        let label = self.label.clone();
 
         // 自动心跳包
         let write_tx_clone1 = write_tx_am.clone();
@@ -115,7 +119,7 @@ impl ExtendedStreamClient {
 
                 trace!("Extended_usdt_swap socket 连接中……");
                 StreamUtils::ws_connect_async(is_shutdown_arc.clone(), handle_function.clone(), request,
-                                              false, tag.clone(), vec![], write_to_socket_rx_arc.clone(),
+                                              false, label.clone(), vec![], write_to_socket_rx_arc.clone(),
                                               Self::message_text, Self::message_ping, Self::message_pong, Self::message_binary).await;
 
                 warn!("Extended_usdt_swap socket 断连,1s以后重连……");

+ 9 - 9
src/main.rs

@@ -3,7 +3,7 @@ mod exchange;
 mod strategy;
 mod data_manager;
 
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
 use backtrace::Backtrace;
@@ -91,10 +91,11 @@ async fn main() {
 ///
 /// # Returns
 pub async fn run_extended_subscriptions(running: Arc<AtomicBool>) -> Result<()> {
-    let is_testnet = true;
-    
+    let is_testnet = false;
+    let symbol = "BTC-USD".to_string();
+
     let stream_client_list = vec![
-        ExtendedStreamClient::order_books("ExtendedOrderBooks".to_string(), None, "BTC-USD".to_string(), is_testnet)
+        ExtendedStreamClient::best_prices(format!("ExtendedBestPrices_{}", symbol), None, symbol, is_testnet)
     ];
 
     // 数据管理及消息分发
@@ -110,11 +111,8 @@ pub async fn run_extended_subscriptions(running: Arc<AtomicBool>) -> Result<()>
         let fun = move |response: Response| {
             if response.code != 200 {
                 error!("出现错误代码:{}", serde_json::to_string_pretty(&response.data).unwrap());
-
-                panic!("出现错误代码:{}", serde_json::to_string_pretty(&response.data).unwrap());
             }
 
-            // info!("{}", serde_json::to_string_pretty(&response.data).unwrap());
             let dm_clone = Arc::clone(&dm);
             async move {
                 let mut dm_guard = dm_clone.lock().await;
@@ -122,8 +120,10 @@ pub async fn run_extended_subscriptions(running: Arc<AtomicBool>) -> Result<()>
                 // 记录消息延迟
                 dm_guard.record_latency(response.received_time, response.reach_time);
 
-                // 交给消息分发函数
-                dm_guard.dispatch_message(&response).await.unwrap();
+                // 交给消息分发函数,并在此处消费掉错误消息
+                if let Err(e) = dm_guard.dispatch_message(&response).await {
+                    warn!("消息分发过程中出现错误: {}", e);
+                }
             }
         };