Sfoglia il codice sorgente

策略逻辑写完,暂未测试。

skyfffire 2 settimane fa
parent
commit
a430a8511d

+ 13 - 4
src/exchange/extended_rest_client.rs

@@ -161,7 +161,7 @@ impl ExtendedRestClient {
         ).await
     }
 
-    pub async fn get_order(&mut self, id: String) -> Response {
+    pub async fn get_order(&mut self, id: &str) -> Response {
         let params = json!({});
 
         self.request("GET",
@@ -172,7 +172,7 @@ impl ExtendedRestClient {
         ).await
     }
 
-    pub async fn cancel_order(&mut self, id: String) -> Response {
+    pub async fn cancel_order(&mut self, id: &str) -> Response {
         let params = json!({});
 
         self.request("DELETE",
@@ -183,7 +183,7 @@ impl ExtendedRestClient {
         ).await
     }
 
-    pub async fn cancel_order_by_external_id(&mut self, external_id: String) -> Response {
+    pub async fn cancel_order_by_external_id(&mut self, external_id: &str) -> Response {
         let params = json!({
             "external_id": external_id,
         });
@@ -557,6 +557,15 @@ mod tests {
         info!("{}", serde_json::to_string_pretty(&response.data).unwrap());
     }
 
+    #[tokio::test]
+    async fn test_get_order() {
+        let _guard = setup_logging().unwrap();
+        let mut client = get_client().await;
+        let response = client.get_order("1978656082822787072").await;
+
+        info!("{}", serde_json::to_string_pretty(&response.data).unwrap());
+    }
+
     #[tokio::test]
     async fn test_create_order() {
         let _guard = setup_logging().unwrap();
@@ -577,7 +586,7 @@ mod tests {
     async fn test_cancel_order() {
         let _guard = setup_logging().unwrap();
         let mut client = get_client().await;
-        let response = client.cancel_order("1978656082822787072".to_string()).await;
+        let response = client.cancel_order("1978656082822787072").await;
 
         info!("{}", serde_json::to_string_pretty(&response.data).unwrap());
     }

+ 5 - 5
src/exchange/extended_stream_client.rs

@@ -31,7 +31,7 @@ pub struct ExtendedStreamClient {
 #[allow(dead_code)]
 impl ExtendedStreamClient {
     // ============================================= 构造函数 ================================================
-    fn new(label: String, account_option: Option<ExtendedAccount>, subscribe_pattern: String, is_testnet: bool) -> ExtendedStreamClient {
+    fn new(label: &str, account_option: Option<ExtendedAccount>, subscribe_pattern: String, is_testnet: bool) -> ExtendedStreamClient {
         let host = match is_testnet {
             true => {
                 "wss://api.starknet.sepolia.extended.exchange/stream.extended.exchange/v1/".to_string()  // testnet
@@ -44,7 +44,7 @@ impl ExtendedStreamClient {
         let address_url = format!("{}{}", host, subscribe_pattern);
 
         ExtendedStreamClient {
-            label,
+            label: label.to_string(),
             address_url,
             account_option,
             heartbeat_time: 1000 * 10,
@@ -52,11 +52,11 @@ impl ExtendedStreamClient {
     }
 
     // ============================================= 订阅函数 ================================================
-    pub fn order_books(label: String, account_option: Option<ExtendedAccount>, symbol: String, is_testnet: bool) -> ExtendedStreamClient {
+    pub fn order_books(label: &str, account_option: Option<ExtendedAccount>, symbol: &str, is_testnet: bool) -> ExtendedStreamClient {
         Self::new(label, account_option, format!("orderbooks/{}", symbol), is_testnet)
     }
     
-    pub fn best_prices(label: String, account_option: Option<ExtendedAccount>, symbol: String, is_testnet: bool) -> ExtendedStreamClient {
+    pub fn best_prices(label: &str, account_option: Option<ExtendedAccount>, symbol: &str, is_testnet: bool) -> ExtendedStreamClient {
         Self::new(label, account_option, format!("orderbooks/{}?depth=1", symbol), is_testnet)
     }
 
@@ -307,7 +307,7 @@ mod tests {
         let (write_tx, write_rx) = futures_channel::mpsc::unbounded::<Message>();
         let _guard = setup_logging().unwrap();
 
-        let mut ws = ExtendedStreamClient::order_books("Extended".to_string(), None, "BTC-USD".to_string(), true);
+        let mut ws = ExtendedStreamClient::order_books("Extended", None, "BTC-USD", true);
 
         let fun = move |response: Response| {
             info!("{}", serde_json::to_string_pretty(&response.data).unwrap());

+ 1 - 1
src/exchange/mod.rs

@@ -1,3 +1,3 @@
 pub mod extended_stream_client;
 pub mod extended_account;
-mod extended_rest_client;
+pub mod extended_rest_client;

+ 19 - 5
src/main.rs

@@ -13,6 +13,8 @@ use tokio_tungstenite::tungstenite::Message;
 use tracing::{error, info, warn};
 use utils::log_setup;
 use crate::data_manager::DataManager;
+use crate::exchange::extended_account::ExtendedAccount;
+use crate::exchange::extended_rest_client::ExtendedRestClient;
 use crate::exchange::extended_stream_client::ExtendedStreamClient;
 use crate::strategy::Strategy;
 use crate::utils::response::Response;
@@ -93,18 +95,30 @@ async fn main() {
 /// # Returns
 pub async fn run_extended_subscriptions(running: Arc<AtomicBool>) -> Result<()> {
     let is_testnet = false;
-    let symbol = "BTC-USD".to_string();
+    let market = "BTC-USD";
+    let account = ExtendedAccount::new(
+        "9ae4030902ab469a1bae8a90464e2e91",
+        "0x71e16e49b717b851ced8347cf0dfa8f490bfb826323b9af624a66285dc99672",
+        "0x47cdde8952945c13460f9129644eade096100810fba59de05452b34aacecff6",
+        220844,
+    );
+    let is_testnet = false;
 
+    // 订阅数据的客户端
     let stream_client_list = vec![
-        ExtendedStreamClient::best_prices(format!("ExtendedBestPrices_{}", symbol), None, symbol, is_testnet)
+        ExtendedStreamClient::best_prices(format!("ExtendedBestPrices_{}", market).as_str(), None, market, is_testnet)
     ];
+    
+    // rest客户端
+    let rest_client = ExtendedRestClient::new("ExtendedRestClient", Some(account), market, is_testnet).await?;
+    let rest_client_am = Arc::new(Mutex::new(rest_client));
 
     // 数据管理及消息分发
     let data_manager = DataManager::new();
     let data_manager_am = Arc::new(Mutex::new(data_manager));
 
     // 策略执行
-    let strategy = Strategy::new();
+    let strategy = Strategy::new(rest_client_am.clone());
     let strategy_am = Arc::new(Mutex::new(strategy));
 
     // 异步去订阅、并阻塞
@@ -122,8 +136,8 @@ pub async fn run_extended_subscriptions(running: Arc<AtomicBool>) -> Result<()>
             let dm_clone = Arc::clone(&dm);
             let sm_clone = Arc::clone(&sm);
             async move {
-                // 数据不新鲜直接跳过,300ms为阈值
-                if response.reach_time - response.received_time > 300 {
+                // 数据不新鲜直接跳过
+                if response.reach_time - response.received_time > 500 {
                     return
                 }
                 

+ 194 - 15
src/strategy.rs

@@ -1,10 +1,16 @@
-use anyhow::Result;
+use std::str::FromStr;
+use std::sync::Arc;
+use anyhow::{anyhow, bail, Result};
 use rust_decimal::Decimal;
 use std::time::{Duration, Instant};
 use rust_decimal_macros::dec;
+use serde_json::Value;
+use tokio::sync::Mutex;
 use tokio::time::sleep;
 use tracing::{info, warn};
 use crate::data_manager::DataManager;
+use crate::exchange::extended_rest_client::ExtendedRestClient;
+use crate::utils::response::Response;
 
 #[derive(Debug, Clone, PartialEq)]
 pub enum StrategyState {
@@ -27,18 +33,20 @@ pub enum StrategyState {
 #[allow(dead_code)]
 pub struct Strategy {
     state: StrategyState,
-    order_quantity: Decimal,            // 写死的订单数量
-    filled_quantity: Decimal,           // 成交数量
-    min_order_interval_ms: u128,        // 最小下单间隔(毫秒)
+    order_quantity: Decimal,                        // 写死的订单数量
+    filled_quantity: Decimal,                       // 成交数量
+    min_order_interval_ms: u128,                    // 最小下单间隔(毫秒)
+    rest_client: Arc<Mutex<ExtendedRestClient>>,    // rest客户端
 }
 
 impl Strategy {
-    pub fn new() -> Strategy {
+    pub fn new(client_am: Arc<Mutex<ExtendedRestClient>>) -> Strategy {
         Strategy {
             state: StrategyState::Idle,
             order_quantity: dec!(0.001),
             filled_quantity: Decimal::ZERO,
             min_order_interval_ms: 200,
+            rest_client: client_am,
         }
     }
 
@@ -159,9 +167,9 @@ impl Strategy {
         sleep(Duration::from_millis(3000)).await;
 
         // 撤单后检查是否有成交
-        match self.check_order_filled(&order_id).await {
+        match self.check_order_partially_filled(&order_id).await {
             Ok(true) => {
-                info!("撤单发现有成交,准备执行市价单");
+                info!("撤单发现有成交,准备执行市价单");
                 self.state = StrategyState::ExecutingMarketOrder { last_order_time };
                 Ok(())
             }
@@ -204,7 +212,10 @@ impl Strategy {
         match self.check_order_filled(&order_id).await {
             Ok(true) => {
                 info!("市价单已成交,返回空闲状态");
+
                 self.state = StrategyState::Idle;
+                self.filled_quantity = Decimal::ZERO;
+
                 Ok(())
             }
             Ok(false) => {
@@ -233,28 +244,196 @@ impl Strategy {
     /// 下限价买单
     async fn place_limit_buy_order(&self, price: Decimal, quantity: Decimal) -> Result<String> {
         info!("下限价买单: 价格={}, 数量={}", price, quantity);
-        // TODO: 实现具体的下单逻辑
-        Ok("order_id_placeholder".to_string())
+
+        let mut client = self.rest_client.lock().await;
+
+        // 调用client执行下单
+        let create_result = client.post_order(
+            "LIMIT",
+            "BUY",
+            quantity.to_string().as_str(),
+            price.to_string().as_str()).await;
+
+        // 解析下单结果并返回
+        self.match_create_order_result(&create_result)
     }
 
     /// 下市价卖单
     async fn place_market_sell_order(&self, quantity: Decimal) -> Result<String> {
         info!("下市价卖单: 数量={}", quantity);
-        // TODO: 实现具体的下单逻辑
-        Ok("order_id_placeholder".to_string())
+        let mut client = self.rest_client.lock().await;
+
+        // 调用client执行下单
+        let create_result = client.post_order(
+            "MARKET",
+            "SELL",
+            quantity.to_string().as_str(),
+            "-1").await;
+
+        // 解析下单结果并返回
+        self.match_create_order_result(&create_result)
     }
 
     /// 撤单
     async fn cancel_order(&self, order_id: &str) -> Result<()> {
         info!("撤单: {}", order_id);
-        // TODO: 实现具体的撤单逻辑
+
+        let mut client = self.rest_client.lock().await;
+
+        let response = client.cancel_order(order_id).await;
+        let value = &response.data;
+
+        // 预先捕获整个 Value 的字符串表示,用于错误报告
+        let value_str = serde_json::to_string(&value).unwrap_or_else(|_| "无法序列化 JSON Value".to_string());
+
+        // 获取status
+        let status = value.get("status")
+            .and_then(|v| v.as_str())
+            .ok_or_else(|| anyhow!("撤单-获取 'status' 失败,原始JSON:{}", value_str))?;
+
+        // 判定status
+        if status != "OK" {
+            bail!("撤单失败,状态不为OK,原始JSON:{}", value_str)
+        }
+
         Ok(())
     }
 
     /// 检查订单是否完全成交
-    async fn check_order_filled(&self, order_id: &str) -> Result<bool> {
+    async fn check_order_filled(&mut self, order_id: &str) -> Result<bool> {
         info!("检查订单是否成交: {}", order_id);
-        // TODO: 实现具体的查询逻辑
-        Ok(false)
+
+        let query_result = self.get_order_result(order_id).await;
+
+        match query_result {
+            Ok(data) => {
+                let data_str = serde_json::to_string(&data).unwrap_or_else(|_| "无法序列化 JSON Value".to_string());
+
+                // 获取order的状态[NEW, PARTIALLY_FILLED, FILLED, UNTRIGGERED, CANCELLED, REJECTED, EXPIRED, TRIGGERED]
+                let status = data.get("status")
+                    .and_then(|v| v.as_str())
+                    .ok_or_else(|| anyhow!("查单-获取 'data.status' 失败,原始JSON:{}", data_str))?;
+                
+                // 只考虑完全成交的
+                if status == "FILLED" {
+                    // 获取真实成交数量
+                    let filled_qty = data.get("filledQty")
+                        .and_then(|v| v.as_str())
+                        .ok_or_else(|| anyhow!("查单-获取 'data.filledQty' 失败,原始JSON:{}", data_str))
+                        .and_then(|v| Decimal::from_str(v)
+                            .map_err(|e| anyhow!("查单-解析 'data.filledQty' 为 Decimal 失败: {}, 值: {}", e, v))
+                        )?;
+
+                    self.filled_quantity = filled_qty;
+
+                    Ok(true)
+                } else {
+                    Ok(false)
+                }
+            }
+            Err(error) => {
+                bail!("查单失败: {}", error);
+            }
+        }
+    }
+
+    /// 检查订单是否有部分成交
+    async fn check_order_partially_filled(&mut self, order_id: &str) -> Result<bool> {
+        info!("检查订单是否有部分成交: {}", order_id);
+
+        let query_result = self.get_order_result(order_id).await;
+        
+        match query_result {
+            Ok(data) => {
+                let data_str = serde_json::to_string(&data).unwrap_or_else(|_| "无法序列化 JSON Value".to_string());
+
+                // 获取order的状态[NEW, PARTIALLY_FILLED, FILLED, UNTRIGGERED, CANCELLED, REJECTED, EXPIRED, TRIGGERED]
+                let status = data.get("status")
+                    .and_then(|v| v.as_str())
+                    .ok_or_else(|| anyhow!("查单-获取 'data.status' 失败,原始JSON:{}", data_str))?;
+
+                // 只考虑完全成交的
+                if status == "FILLED" || status == "PARTIALLY_FILLED" {
+                    // 获取真实成交数量
+                    let filled_qty = data.get("filledQty")
+                        .and_then(|v| v.as_str())
+                        .ok_or_else(|| anyhow!("查单-获取 'data.filledQty' 失败,原始JSON:{}", data_str))
+                        .and_then(|v| Decimal::from_str(v)
+                            .map_err(|e| anyhow!("查单-解析 'data.filledQty' 为 Decimal 失败: {}, 值: {}", e, v))
+                        )?;
+
+                    self.filled_quantity = filled_qty;
+
+                    Ok(true)
+                } else {
+                    Ok(false)
+                }
+            }
+            Err(error) => {
+                bail!("查单失败: {}", error);
+            }
+        }
+    }
+
+    async fn get_order_result(&self, id: &str) -> Result<Value> {
+        let mut client = self.rest_client.lock().await;
+        let response = client.get_order(id).await;
+
+        let value = &response.data;
+
+        // 预先捕获整个 Value 的字符串表示,用于错误报告
+        let value_str = serde_json::to_string(&value).unwrap_or_else(|_| "无法序列化 JSON Value".to_string());
+
+        // 获取status
+        let status = value.get("status")
+            .and_then(|v| v.as_str())
+            .ok_or_else(|| anyhow!("查单-获取 'status' 失败,原始JSON:{}", value_str))?;
+
+        // 判定status
+        if status != "OK" {
+            bail!("查单失败,状态不为OK,原始JSON:{}", value_str)
+        }
+
+        // 尝试获取 data 字段
+        let data = value.get("data")
+            .ok_or_else(|| anyhow!("下单-获取 'data' 字段失败,原始 JSON: {}", value_str))?;
+
+        // data返回给查单
+        Ok(data.clone())
+    }
+
+    fn match_create_order_result(&self, create_result: &Result<Response>) -> Result<String> {
+        match create_result {
+            Ok(response) => {
+                let value = &response.data;
+
+                // 预先捕获整个 Value 的字符串表示,用于错误报告
+                let value_str = serde_json::to_string(&value).unwrap_or_else(|_| "无法序列化 JSON Value".to_string());
+
+                // 获取status
+                let status = value.get("status")
+                    .and_then(|v| v.as_str())
+                    .ok_or_else(|| anyhow!("下单-获取 'status' 失败,原始JSON:{}", value_str))?;
+
+                // 判定status
+                if status != "OK" {
+                    bail!("下单失败,状态不为OK,原始JSON:{}", value_str)
+                }
+
+                // 尝试获取 data 字段
+                let data = value.get("data")
+                    .ok_or_else(|| anyhow!("下单-获取 'data' 字段失败,原始 JSON: {}", value_str))?;
+
+                // 获取order的id
+                let id = data.get("id")
+                    .and_then(|v| v.as_str())
+                    .ok_or_else(|| anyhow!("下单-获取 'data.id' 失败,原始JSON:{}", value_str))?;
+
+                Ok(id.to_string())
+            }
+            Err(error) => {
+                bail!("下单失败:{}", error);
+            }
+        }
     }
 }