Jelajahi Sumber

异步闭包调用,完美完成,数据到本地可以立即执行。

skyfffire 2 tahun lalu
induk
melakukan
45e056a1a2
3 mengubah file dengan 307 tambahan dan 235 penghapusan
  1. 119 99
      src/exchange_libs.rs
  2. 41 24
      src/exchange_middle_ware.rs
  3. 147 112
      src/main.rs

+ 119 - 99
src/exchange_libs.rs

@@ -1,9 +1,11 @@
 use std::collections::{BTreeMap, HashMap};
 use std::{env, io, thread};
+use std::future::Future;
 use std::io::{Write};
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::ptr::null;
 use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::Mutex;
 use std::time::Duration;
 use chrono::Utc;
 use reqwest;
@@ -19,7 +21,7 @@ use url::Url;
 
 #[cfg(test)]
 mod tests {
-    use crate::exchange_libs::{BinanceExc, is_proxy, OkxExc, ReqData};
+    use crate::exchange_libs::{BinanceExc, is_proxy, OkxExc, ResponseData};
 
     #[tokio::test]//测试获取ok 账号信息
     async fn test_okx_get_acc() {
@@ -30,7 +32,7 @@ mod tests {
             "556DAB6773CA26DDAAA114F7044138CA".to_string(),
             "rust_Test123".to_string(),
         );
-        let req: ReqData = okx_exc.okx_acc("USDT").await;
+        let req: ResponseData = okx_exc.okx_acc("USDT").await;
         print!("---响应:code:{}", req.code);
         print!("---响应:mes:{}", req.message);
         print!("---响应:data:{}", req.data);
@@ -45,7 +47,7 @@ mod tests {
             "556DAB6773CA26DDAAA114F7044138CA".to_string(),
             "rust_Test123".to_string(),
         );
-        let req: ReqData = okx_exc.okx_order("CORE-USDT", "cash", "buy", "limit", "0.8555", "1").await;
+        let req: ResponseData = okx_exc.okx_order("CORE-USDT", "cash", "buy", "limit", "0.8555", "1").await;
         print!("---响应:code:{}", req.code);
         print!("---响应:mes:{}", req.message);
         print!("---响应:data:{}", req.data);
@@ -60,7 +62,7 @@ mod tests {
             "556DAB6773CA26DDAAA114F7044138CA".to_string(),
             "rust_Test123".to_string(),
         );
-        let req2: ReqData = okx_exc.okx_get_order("CORE-USDT", "611949361383612427").await;
+        let req2: ResponseData = okx_exc.okx_get_order("CORE-USDT", "611949361383612427").await;
         print!("---响应:code:{}", req2.code);
         print!("---响应:mes:{}", req2.message);
         print!("---响应:data:{}", req2.data);
@@ -75,7 +77,7 @@ mod tests {
             "556DAB6773CA26DDAAA114F7044138CA".to_string(),
             "rust_Test123".to_string(),
         );
-        let req3: ReqData = okx_exc.okx_revocation_order("CORE-USDT", "611950727669751811").await;
+        let req3: ResponseData = okx_exc.okx_revocation_order("CORE-USDT", "611950727669751811").await;
         print!("---响应:code:{}", req3.code);
         print!("---响应:mes:{}", req3.message);
         print!("---响应:data:{}", req3.data);
@@ -89,7 +91,7 @@ mod tests {
             "a4cf4f54-f4d3-447d-a57c-166fd1ead2e0".to_string(),
             "556DAB6773CA26DDAAA114F7044138CA".to_string(),
         );
-        let req3: ReqData = binance_exc.binance_k("BTCUSDT", "5m", "20").await;
+        let req3: ResponseData = binance_exc.binance_k("BTCUSDT", "5m", "20").await;
         print!("---响应:code:{}", req3.code);
         print!("---响应:mes:{}", req3.message);
         print!("---响应:data:{}", req3.data);
@@ -102,7 +104,7 @@ mod tests {
             "a4cf4f54-f4d3-447d-a57c-166fd1ead2e0".to_string(),
             "556DAB6773CA26DDAAA114F7044138CA".to_string(),
         );
-        let req3: ReqData = binance_exc.binance_depth("BTCUSDT", "20").await;
+        let req3: ResponseData = binance_exc.binance_depth("BTCUSDT", "20").await;
         print!("---响应:code:{}", req3.code);
         print!("---响应:mes:{}", req3.message);
         print!("---响应:data:{}", req3.data);
@@ -121,7 +123,7 @@ impl BinanceExc {
     }
 
     //币安-深度信息
-    pub async fn binance_depth(&self, symbol: &str, limit: &str) -> ReqData {
+    pub async fn binance_depth(&self, symbol: &str, limit: &str) -> ResponseData {
         let base_url = "/api/v3/depth?symbol=".to_string() + &symbol + "&limit=" + &limit;
 
         let result = self.get(base_url.to_string()).await;
@@ -130,14 +132,14 @@ impl BinanceExc {
                 req_data
             }
             Err(err) => {
-                let error = ReqData::error(format!("json 解析失败:{}", err));
+                let error = ResponseData::error(format!("json 解析失败:{}", err));
                 error
             }
         }
     }
 
     //币安-k线
-    pub async fn binance_k(&self, symbol: &str, interval: &str, limit: &str) -> ReqData {
+    pub async fn binance_k(&self, symbol: &str, interval: &str, limit: &str) -> ResponseData {
         let base_url = "/api/v3/klines?symbol=".to_string() + &symbol + "&interval=" + &interval + "&limit=" + &limit;
 
         let result = self.get(base_url.to_string()).await;
@@ -146,14 +148,14 @@ impl BinanceExc {
                 req_data
             }
             Err(err) => {
-                let error = ReqData::error(format!("json 解析失败:{}", err));
+                let error = ResponseData::error(format!("json 解析失败:{}", err));
                 error
             }
         }
     }
 
     //普通get
-    async fn get(&self, request_path: String) -> Result<(ReqData), reqwest::Error> {
+    async fn get(&self, request_path: String) -> Result<(ResponseData), reqwest::Error> {
         let mut req_data;
         let base_url = self.base_url.clone();
         // 发起GET请求
@@ -166,11 +168,11 @@ impl BinanceExc {
             // 读取响应的内容
             let body = response.text().await?;
             println!("Response body:\n{}", body);
-            req_data = ReqData::new("0".to_string(), "success".to_string(), body);
+            req_data = ResponseData::new("0".to_string(), "success".to_string(), body);
         } else {
             let body = response.text().await?;
             println!("Request failed with status: {}", body);
-            req_data = ReqData::error(body.to_string())
+            req_data = ResponseData::error(body.to_string())
         }
         Ok((req_data))
     }
@@ -191,7 +193,7 @@ impl OkxExc {
     }
 
     //获取订单信息
-    pub async fn okx_get_order(&self, inst_id: &str, ord_id: &str) -> ReqData {
+    pub async fn okx_get_order(&self, inst_id: &str, ord_id: &str) -> ResponseData {
         let mut btree_map: BTreeMap<&str, &str> = BTreeMap::new();
         btree_map.insert("instId", inst_id);//产品Id
         btree_map.insert("ordId", ord_id);//顶顶那
@@ -205,7 +207,7 @@ impl OkxExc {
     }
 
     //撤单接口
-    pub async fn okx_revocation_order(&self, inst_id: &str, ord_id: &str) -> ReqData {
+    pub async fn okx_revocation_order(&self, inst_id: &str, ord_id: &str) -> ResponseData {
         let mut btree_map: BTreeMap<&str, &str> = BTreeMap::new();
         btree_map.insert("instId", inst_id);//产品Id
         btree_map.insert("ordId", ord_id);//顶顶那
@@ -219,7 +221,7 @@ impl OkxExc {
     }
 
     //下单接口
-    pub async fn okx_order(&self, inst_id: &str, td_mode: &str, side: &str, ord_type: &str, px: &str, sz: &str) -> ReqData {
+    pub async fn okx_order(&self, inst_id: &str, td_mode: &str, side: &str, ord_type: &str, px: &str, sz: &str) -> ResponseData {
         let mut btree_map: BTreeMap<&str, &str> = BTreeMap::new();
         btree_map.insert("instId", inst_id);//产品Id
         btree_map.insert("tdMode", td_mode);//交易模式
@@ -237,7 +239,7 @@ impl OkxExc {
     }
 
     //账户信息
-    pub async fn okx_acc(&self, ccy: &str) -> ReqData {
+    pub async fn okx_acc(&self, ccy: &str) -> ResponseData {
         let mut btree_map: BTreeMap<&str, &str> = BTreeMap::new();
         btree_map.insert("ccy", ccy);
 
@@ -250,8 +252,8 @@ impl OkxExc {
     }
 
     //带认证-get
-    pub(crate) async fn get_v(&self, request_path: String, params: BTreeMap<&str, &str>) -> Result<(ReqData), reqwest::Error> {
-        let mut req_data: ReqData;
+    pub(crate) async fn get_v(&self, request_path: String, params: BTreeMap<&str, &str>) -> Result<(ResponseData), reqwest::Error> {
+        let mut req_data: ResponseData;
 
         /*请求接口与 地址*/
         let base_url = self.base_url.to_string();
@@ -289,18 +291,18 @@ impl OkxExc {
             // 读取响应的内容
             let body = response.text().await?;
             println!("okx_acc-Response body:\n{}", body);
-            req_data = ReqData::new("0".to_string(), "success".to_string(), body);
+            req_data = ResponseData::new("0".to_string(), "success".to_string(), body);
         } else {
             let body = response.text().await?;
             println!("okx_acc-Request failed with status: {}", body);
-            req_data = ReqData::error(body.to_string())
+            req_data = ResponseData::error(body.to_string())
         }
         Ok((req_data))
     }
 
     //带认证-post
-    async fn post_v(&self, request_path: String, params: BTreeMap<&str, &str>) -> Result<(ReqData), reqwest::Error> {
-        let mut req_data: ReqData;
+    async fn post_v(&self, request_path: String, params: BTreeMap<&str, &str>) -> Result<(ResponseData), reqwest::Error> {
+        let mut req_data: ResponseData;
 
         /*请求接口与 地址*/
         let base_url = self.base_url.to_string();
@@ -345,11 +347,11 @@ impl OkxExc {
             // 读取响应的内容
             let body = response.text().await?;
             println!("okx_order-Response body:\n{}", body);
-            req_data = ReqData::new("0".to_string(), "success".to_string(), body);
+            req_data = ResponseData::new("0".to_string(), "success".to_string(), body);
         } else {
             let body = response.text().await?;
             println!("okx_order-Request failed with status: {}", body);
-            req_data = ReqData::error(body.to_string())
+            req_data = ResponseData::error(body.to_string())
         }
         Ok((req_data))
     }
@@ -395,7 +397,7 @@ impl OkxExc {
     }
 
     //req_data 解析
-    fn req_data_analysis(&self, result: Result<ReqData, reqwest::Error>) -> ReqData {
+    fn req_data_analysis(&self, result: Result<ResponseData, reqwest::Error>) -> ResponseData {
         match result {
             Ok(req_data) => {
                 if req_data.code != "0" {
@@ -411,14 +413,14 @@ impl OkxExc {
                     // println!("--解析成功----code:{}",code);
                     // println!("--解析成功----data:{}",data);
                     // println!("--解析成功----msg:{}",msg);
-                    let success = ReqData::new(code.parse().unwrap(),
+                    let success = ResponseData::new(code.parse().unwrap(),
                                                msg.parse().unwrap(),
                                                data.parse().unwrap());
                     success
                 }
             }
             Err(err) => {
-                let error = ReqData::error(format!("json 解析失败:{}", err));
+                let error = ResponseData::error(format!("json 解析失败:{}", err));
                 error
             }
         }
@@ -465,19 +467,19 @@ fn get_timestamp() -> String {
 
 //统一返回
 #[derive(Debug)]
-pub struct ReqData {
+pub struct ResponseData {
     pub code: String,
     pub message: String,
     pub data: String,
 }
 
-impl ReqData {
-    pub fn new(code: String, message: String, data: String) -> ReqData {
+impl ResponseData {
+    pub fn new(code: String, message: String, data: String) -> ResponseData {
         // original_string.replace("world", "Rust");
-        ReqData { code, message, data }
+        ResponseData { code, message, data }
     }
-    pub fn error(message: String) -> ReqData {
-        ReqData { code: "-1".to_string(), message: "请求失败:".to_string() + &message, data: "".to_string() }
+    pub fn error(message: String) -> ResponseData {
+        ResponseData { code: "-1".to_string(), message: "请求失败:".to_string() + &message, data: "".to_string() }
     }
 }
 
@@ -539,33 +541,41 @@ impl SocketTool {
             subscription: subscription,
         }
     }
-    pub(crate) fn run<F: FnMut(ReqData)>(&self, mut parse_fn: F) {
-        while true {//一个粗糙的 断开重连操作
-            /*****消息溜***/
-            let mut stdout = io::stdout();
-            let mut stderr = io::stderr();
-
-            /*****socket配置信息***/
-            let request_url = Url::parse(self.request_url.as_str()).unwrap();
-            let ip_array: Vec<&str> = self.ip.split(".").collect();
-            let proxy_address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(
-                ip_array[0].parse().unwrap(),
-                ip_array[1].parse().unwrap(),
-                ip_array[2].parse().unwrap(),
-                ip_array[3].parse().unwrap())
-            ), self.port);
-            let websocket_config = Some(WebSocketConfig {
-                max_send_queue: Some(16),
-                max_message_size: Some(16 * 1024 * 1024),
-                max_frame_size: Some(16 * 1024 * 1024),
-                accept_unmasked_frames: false,
-            });
-            let max_redirects = 5;
+    pub(crate) fn run<F, Fut>(&self, parse_fn: F)
+    where
+        F: Fn(ResponseData) -> Fut + 'static + Send + Sync,
+        Fut: Future<Output = ()> + Send + 'static,
+    {
+        /*****消息溜***/
+        let mut stdout = io::stdout();
+        let mut stderr = io::stderr();
+
+        /*****socket配置信息***/
+        let request_url = Url::parse(self.request_url.as_str()).unwrap();
+        let ip_array: Vec<&str> = self.ip.split(".").collect();
+        let proxy_address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(
+            ip_array[0].parse().unwrap(),
+            ip_array[1].parse().unwrap(),
+            ip_array[2].parse().unwrap(),
+            ip_array[3].parse().unwrap())
+        ), self.port);
+        let websocket_config = Some(WebSocketConfig {
+            max_send_queue: Some(16),
+            max_message_size: Some(16 * 1024 * 1024),
+            max_frame_size: Some(16 * 1024 * 1024),
+            accept_unmasked_frames: false,
+        });
+        let max_redirects = 5;
+
+        let parse_fn = Arc::new(Mutex::new(parse_fn)); // Wrap the closure in an Arc<Mutex<_>>
+
+        // 一个粗糙的 断开重连操作
+        // loop {
             /*****判断代理IP是否为空,空则不走代理*****/
             if self.ip.len() > 0 {
                 println!("----socket-走代理");
                 let (mut socket, response) =
-                    connect_with_proxy(request_url, proxy_address, websocket_config, max_redirects)
+                    connect_with_proxy(request_url.clone(), proxy_address, websocket_config, max_redirects)
                         .expect("Can't connect(无法连接)");
 
                 /******登陆认证********/
@@ -590,6 +600,8 @@ impl SocketTool {
                 /******数据读取********/
                 loop {
                     if !socket.can_read() {
+                        println!("不能读取的socket");
+
                         continue;
                     }
 
@@ -604,45 +616,30 @@ impl SocketTool {
                                 writeln!(stdout, "-币安--订阅成功:{0}", text).unwrap();
                             } else if json_value.get("event").is_some() {
                                 writeln!(stdout, "-OKX--订阅成功:{0}", text).unwrap();
-                                // let event_v = serde_json::to_string(&json_value["event"]).unwrap();
-                                // println!("---字符串比较:{0}---{1}", event_v,(event_v.to_string() == "login".to_string()));
-                                // if event_v == "subscribe" {
-                                //     writeln!(stdout, "-OKX--订阅成功:{:?}", text).unwrap();
-                                // } else if event_v == "error" {
-                                //     writeln!(stdout, "-OKX--订阅失败:{:?}", text).unwrap();
-                                //     let code_v = serde_json::to_string(&json_value["code"]).unwrap();
-                                //     let msg_v = serde_json::to_string(&json_value["msg"]).unwrap();
-                                //
-                                //     let reqData = ReqData::new(code_v.to_string(),
-                                //                                msg_v.to_string(),
-                                //                                "".to_string());
-                                //     parse_fn(reqData);
-                                // } else if event_v == "login".to_string() {
-                                //     let code_v = serde_json::to_string(&json_value["code"]).unwrap();
-                                //     if code_v == "0".to_string() {
-                                //         writeln!(stdout, "---登陆成功:{0}", text).unwrap();
-                                //     } else {
-                                //         writeln!(stdout, "---登陆失败:{0}", text).unwrap();
-                                //     }
-                                // }
                             } else {
                                 //便 --推送数据
                                 // writeln!(stdout, "---推送数据:{0}", text).unwrap();
-                                let reqData = ReqData::new("0".to_string(),
+                                let rsp_data = ResponseData::new("0".to_string(),
                                                            "success".to_string(),
                                                            text);
-                                parse_fn(reqData);
+                                // parse_fn(rsp_data);
+                                let parse_fn = Arc::clone(&parse_fn); // Clone the Arc for each iteration
+
+                                tokio::spawn(async move {
+                                    let parse_fn = parse_fn.lock().await;
+                                    parse_fn(rsp_data).await;
+                                });
                             }
                         }
                         Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Close(_)) => {
-                            socket.write_message(Message::text("pong"));
+                            socket.write_message(Message::text("pong")).expect("TODO: panic message");
                             // writeln!(stdout, "ping----------pong").unwrap();
                             writeln!(stdout, "ping----------pong").unwrap();
                         }
                         Err(error) => {
                             writeln!(stderr, "Error receiving message: {}", error).unwrap();
-                            let reqData = ReqData::error("socket 发生错误!".to_string());
-                            parse_fn(reqData);
+                            let rsp_data = ResponseData::error("socket 发生错误!".to_string());
+                            // parse_fn(rsp_data);
                             break;
                         }
                         _ => {}
@@ -654,7 +651,7 @@ impl SocketTool {
                 // 提示,并未找到好的优化方式,
                 println!("----socket-没代理");
                 let (mut socket, response) =
-                    connect(request_url)
+                    connect(request_url.clone())
                         .expect("Can't connect(无法连接)");
 
                 /******登陆认证********/
@@ -692,10 +689,10 @@ impl SocketTool {
                                 writeln!(stdout, "---订阅成功:{:?}", text).unwrap();
                             } else {
                                 writeln!(stdout, "---推送数据:{0}", text).unwrap();
-                                let reqData = ReqData::new("0".to_string(),
+                                let rsp_data = ResponseData::new("0".to_string(),
                                                            "success".to_string(),
                                                            text);
-                                parse_fn(reqData);
+                                // parse_fn(rsp_data);
                             }
                         }
                         Ok(Message::Ping(_)) | Ok(Message::Pong(_)) | Ok(Message::Close(_)) => {
@@ -705,9 +702,11 @@ impl SocketTool {
                         }
                         Err(error) => {
                             writeln!(stderr, "Error receiving message: {}", error).unwrap();
-                            let reqData = ReqData::error("socket 发生错误!".to_string());
-                            parse_fn(reqData);
-                            break;
+                            // let rsp_data = ResponseData::error("socket 发生错误!".to_string());
+                            // tokio::spawn(async move {
+                            //     parse_fn(rsp_data).await;
+                            // });
+                            // break;
                         }
                         _ => {}
                     }
@@ -715,7 +714,7 @@ impl SocketTool {
 
                 socket.close(None).unwrap();
             }
-        }
+        // }
     }
     fn log_in_to_str(&self) -> String {
         let mut login_json_str = String::from("");
@@ -771,22 +770,34 @@ impl SocketTool {
     }
 
     //币安--自定义-订阅
-    pub fn binance_run_custom(b_array: Vec<&str>, parse_fn: impl Fn(ReqData)) {}
+    pub fn binance_run_custom(b_array: Vec<&str>, parse_fn: impl Fn(ResponseData)) {}
 
 
     //币安--深度信息
-    pub fn binance_run_kline(b_array: Vec<&str>, parse_fn: impl Fn(ReqData)) {
+    pub fn binance_run_kline<F, Fut>(b_array: Vec<&str>, parse_fn: F)
+    where
+        F: Fn(ResponseData) -> Fut + 'static + Send + Sync,
+        Fut: Future<Output = ()> + Send + 'static,
+    {
         SocketTool::binance_run(b_array, "kline_1s".to_string(), parse_fn);
     }
 
     //币安--深度信息
-    pub fn binance_run_depth<F: FnMut(ReqData)>(b_array: Vec<&str>, levels: String, parse_fn: F) {
+    pub fn binance_run_depth<F, Fut>(b_array: Vec<&str>, levels: String, parse_fn: F)
+    where
+        F: Fn(ResponseData) -> Fut + 'static + Send + Sync,
+        Fut: Future<Output = ()> + Send + 'static,
+    {
         let str = format!("depth{}@100ms", levels);
         SocketTool::binance_run(b_array, str.to_string(), parse_fn);
     }
 
     //币安--订阅
-    pub fn binance_run<F: FnMut(ReqData)>(b_array: Vec<&str>, subscription_name: String, parse_fn: F) {
+    pub fn binance_run<F, Fut>(b_array: Vec<&str>, subscription_name: String, parse_fn: F)
+    where
+        F: Fn(ResponseData) -> Fut + 'static + Send + Sync,
+        Fut: Future<Output = ()> + Send + 'static,
+    {
         let mut params = vec![];
 
         for item in &b_array {
@@ -817,7 +828,11 @@ impl SocketTool {
 
 
     //OKX-私有频道-订单信息
-    pub fn okx_pr_run_orders<F: FnMut(ReqData)>(b_array: Vec<&str>, btree_map: BTreeMap<String, String>, parse_fn: F) {
+    pub fn okx_pr_run_orders<F, Fut>(b_array: Vec<&str>, btree_map: BTreeMap<String, String>, parse_fn: F)
+    where
+        F: Fn(ResponseData) -> Fut + 'static + Send + Sync,
+        Fut: Future<Output = ()> + Send + 'static,
+    {
         //组装推送信息
         let mut args = vec![];
         for item in &b_array {
@@ -833,8 +848,13 @@ impl SocketTool {
     }
 
     //OKX-私有频道-订阅
-    pub fn okx_pr_run<F: FnMut(ReqData)>(b_array: Vec<&str>, args: Vec<HashMap<String, String>>, btree_map: BTreeMap<String, String>, parse_fn: F) {
-        let url = "wss://ws.okx.com:8443/ws/v5/private";
+    pub fn okx_pr_run<F, Fut>(b_array: Vec<&str>, args: Vec<HashMap<String, String>>, btree_map: BTreeMap<String, String>, parse_fn: F)
+    where
+        F: Fn(ResponseData) -> Fut + 'static + Send + Sync,
+        Fut: Future<Output = ()> + Send + 'static,
+    {
+
+    let url = "wss://ws.okx.com:8443/ws/v5/private";
 
         let pu = "wss://ws.okx.com:8443/ws/v5/public";
         let subscription = json!({

+ 41 - 24
src/exchange_middle_ware.rs

@@ -1,7 +1,14 @@
 use std::io::{BufRead, Error, ErrorKind};
 use std::collections::{BTreeMap};
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::time::Duration;
 use serde_json::json;
-use crate::exchange_libs::{BinanceExc, OkxExc, ReqData, SocketTool};
+use tokio::sync::Mutex;
+use crate::Bot;
+use crate::exchange_libs::{BinanceExc, OkxExc, ResponseData, SocketTool};
 
 // 深度结构体
 #[derive(Debug)]
@@ -119,23 +126,33 @@ impl Exchange {
     // 获取币安深度信息
     // symbol: 交易币对, "BTC_USDT"
     // limit: 返回条数, 最大 5000. 可选值:[5, 10, 20, 50, 100, 500, 1000, 5000]
-    pub async fn get_binance_depth<F: FnMut(Depth)>(&self, symbol: &String, limit: i32, mut callback: F) {
+    pub async fn get_binance_depth(&self, symbol: &String, limit: i32, mut bot_arc: Arc<Mutex<Bot>>) {
         let real_symbol = self.get_real_symbol(symbol, "".to_string());
-        let get_res_data =|res_data:ReqData|{
-            if res_data.code == "0" {
-                let res_data_str = res_data.data;
-                let res_data_json: serde_json::Value = serde_json::from_str(&*res_data_str).unwrap();
-                let depth_asks: Vec<DepthItem> = parse_depth_items(&res_data_json["asks"]);
-                let depth_bids: Vec<DepthItem> = parse_depth_items(&res_data_json["bids"]);
-                let result = Depth {
-                    asks: depth_asks,
-                    bids: depth_bids,
-                };
-                callback(result)
-            } else {
-                panic!("get_binance_depth: {}",res_data.message);
+
+        let get_res_data = move |res_data: ResponseData| {
+            let bot_arc_clone = Arc::clone(&bot_arc);
+
+            async move {
+                if res_data.code == "0" {
+                    let res_data_str = res_data.data;
+                    let res_data_json: serde_json::Value = serde_json::from_str(&*res_data_str).unwrap();
+                    let depth_asks: Vec<DepthItem> = parse_depth_items(&res_data_json["asks"]);
+                    let depth_bids: Vec<DepthItem> = parse_depth_items(&res_data_json["bids"]);
+                    let result = Depth {
+                        asks: depth_asks,
+                        bids: depth_bids,
+                    };
+
+                    {
+                        let bot = bot_arc_clone.lock().await;
+                        bot.depth_handler(result)
+                    }
+                } else {
+                    panic!("get_binance_depth: {}", res_data.message);
+                }
             }
         };
+
         SocketTool::binance_run_depth(vec![&real_symbol], limit.to_string(), get_res_data)
     }
 
@@ -323,7 +340,7 @@ fn parse_depth_items(value: &serde_json::Value) -> Vec<DepthItem> {
 mod tests {
     use std::io::{self, Write};
     use crate::exchange_middle_ware::{Depth, Exchange};
-    use crate::exchange_libs::{is_proxy, ReqData};
+    use crate::exchange_libs::{is_proxy, ResponseData};
 
     // new Exchange
     fn new_exchange() -> Exchange {
@@ -348,14 +365,14 @@ mod tests {
     }
 
     // 测试binance获取深度信息
-    #[tokio::test]
-    async fn test_get_binance_depth() {
-        let exchange = new_exchange();
-        let get_depth_fn = |depth:Depth|{
-            writeln!(io::stdout(), "test_get_binance_depth:{:?}", depth).unwrap();
-        };
-        exchange.get_binance_depth(&"DOGE_USDT".to_string(), 10,get_depth_fn).await;
-    }
+    // #[tokio::test]
+    // async fn test_get_binance_depth() {
+    //     let exchange = new_exchange();
+    //     let get_depth_fn = |depth:Depth|{
+    //         writeln!(io::stdout(), "test_get_binance_depth:{:?}", depth).unwrap();
+    //     };
+    //     exchange.get_binance_depth(&"DOGE_USDT".to_string(), 10,get_depth_fn).await;
+    // }
 
     // 测试binance获取k线
     #[tokio::test]

+ 147 - 112
src/main.rs

@@ -1,13 +1,13 @@
 use std::{env, io, thread};
-use std::thread::sleep;
-use std::time::Duration;
-use chrono::format::Item::Error;
 use ndarray::prelude::*;
 use rust_decimal::prelude::{ToPrimitive};
 use crate::as_libs::*;
 use crate::exchange_middle_ware::{Account, Depth, Exchange, Order, Record};
 use time::OffsetDateTime;
-use crate::exchange_libs::is_proxy;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::Mutex;
+
 
 mod as_libs;
 mod exchange_libs;
@@ -43,7 +43,7 @@ struct OrderInfo {
     time_num: i64
 }
 
-struct Bot {
+pub struct Bot {
     spread_list: Vec<f64>,
     symbol: String,
     limit: i32,
@@ -89,107 +89,6 @@ impl Bot {
         }
     }
 
-    fn depth_handler(&mut self, depth: &Depth){
-        let (spread, mid_price, ask, bid) = get_spread(depth);
-        self.mid_price = mid_price;
-        self.ask = ask;
-        self.bid = bid;
-        if self.spread_list.len() > self.spread_list_limit {
-            self.spread_list.remove(0);
-        }
-        self.spread_list.push(spread);
-
-        println!("depth handler {:?}", depth);
-    }
-
-    async fn start(&mut self){
-        // 使用std::env::var函数获取环境变量的值
-        let okx_access_key= env::var("okx_access_key").unwrap();
-        let okx_secret_key= env::var("okx_secret_key").unwrap();
-        let okx_passphrase= env::var("okx_passphrase").unwrap();
-
-        let exchange:Exchange = Exchange::new(okx_access_key, okx_secret_key, okx_passphrase);
-        let bot = self;
-        let symbol_clone = bot.symbol.clone();
-        let get_depth_fn = |depth:Depth| Self::depth_handler(&mut *bot, &depth);
-        exchange.get_binance_depth(&symbol_clone, 10, get_depth_fn).await;
-
-        loop {
-            let f = match bot.do_logic(&exchange).await {
-                Ok(m) => m,
-                Err(error) => {
-                    eprintln!("异常出现,捕获异常: {}", error);
-                    1
-                },
-            };
-        }
-    }
-
-    async fn do_logic(&mut self, exchange: &Exchange) -> Result<i8, io::Error>{
-        if self.spread_list.len() < self.spread_list_limit {
-            println!("等待spread初始化:{}/{}.", self.spread_list.len(), self.spread_list_limit);
-            return Ok(1_i8)
-        }
-
-        // 使用 max_by 方法和 partial_cmp 进行比较
-        let max_option = self.spread_list.iter().max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
-        match max_option {
-            Some(max_value) => {  },
-            None => eprintln!("列表为空"),
-        }
-        let max_spread =  *max_option.unwrap();
-
-        // 1.获取账户信息
-        let balance_info = match exchange.get_okx_account(&self.symbol).await {
-            Ok(m) => m,
-            Err(error) => {
-                return Err(error);
-            }
-        };
-        println!("info: {:?}", balance_info);
-        // 取消超时订单
-        self.order_list_deal(&balance_info.stocks, exchange).await.expect("订单信息处理异常!");
-
-        // 2.获取最新k线
-        let k_line_data = match  exchange.get_binance_klines(&self.symbol.to_string(), &self.short_interval, &200).await {
-            Ok(m) => m,
-            Err(error) => {
-                return Err(error);
-            }
-        };
-        // rl_list 计算
-        let (q, rl_list) = self.calc_params(&k_line_data, balance_info, max_spread);
-
-        let rate :f64 = (q/self.quantity_max) * 100.0;
-        let index_f :f64 = rl_list.len().to_f64().unwrap() / 100.0 * rate;
-        let index = index_f.round().to_usize().unwrap();
-
-        let order_amount = rl_list.get(index).unwrap().order_amount;
-        let order_dict :OrderDict = OrderDict{
-            order_amount,
-            buy_price: truncate_decimal_places(rl_list.get(index).unwrap().bid, self.price_decimal_places),
-            sell_price: truncate_decimal_places(rl_list.get(index).unwrap().ask, self.price_decimal_places)
-        };
-
-        let now_time = OffsetDateTime::now_utc().unix_timestamp();
-        // 检测交易间隔,发起交易
-        if self.order_info_list.len() > 0 || self.last_buy_time + self.buy_time_limit > now_time {
-            return Ok(0);
-        }
-        // 下单
-        let buy_order_id = exchange.place_okx_order(&self.symbol, &"buy".to_string(), &"limit".to_string(), &order_dict.buy_price.to_string(), &order_dict.order_amount.to_string()).await.unwrap();
-        let order = OrderInfo{
-            id: buy_order_id,
-            sell_price: order_dict.sell_price,
-            time_num: now_time
-        };
-        eprintln!("buy_order: {:?}", order);
-        self.order_info_list.push(order);
-        self.last_buy_time = now_time;
-
-        return Ok(0);
-    }
-
     fn calc_params(&self, k_line_data: &Vec<Record>, balance_info: Account, max_spread: f64) -> (f64, Vec<RiskLevel>){
         // 计算最近20根K线的标准差
         // 3.获取标准差数组的最后一个值
@@ -292,13 +191,88 @@ impl Bot {
     }
 
 
-}
+    async fn do_logic(&mut self, exchange: &Exchange) -> Result<i8, io::Error>{
+        if self.spread_list.len() < self.spread_list_limit {
+            println!("等待spread初始化:{}/{}.", self.spread_list.len(), self.spread_list_limit);
+
+            return Ok(1_i8)
+        }
+
+        // 使用 max_by 方法和 partial_cmp 进行比较
+        let max_option = self.spread_list.iter().max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+        match max_option {
+            Some(_max_value) => {  },
+            None => eprintln!("列表为空"),
+        }
+        let max_spread =  *max_option.unwrap();
+
+        // 1.获取账户信息
+        let balance_info = match exchange.get_okx_account(&self.symbol).await {
+            Ok(m) => m,
+            Err(error) => {
+                return Err(error);
+            }
+        };
+        println!("info: {:?}", balance_info);
+        // 取消超时订单
+        self.order_list_deal(&balance_info.stocks, exchange).await.expect("订单信息处理异常!");
 
+        // 2.获取最新k线
+        let k_line_data = match  exchange.get_binance_klines(&self.symbol.to_string(), &self.short_interval, &200).await {
+            Ok(m) => m,
+            Err(error) => {
+                return Err(error);
+            }
+        };
+        // rl_list 计算
+        let (q, rl_list) = self.calc_params(&k_line_data, balance_info, max_spread);
 
+        let rate :f64 = (q/self.quantity_max) * 100.0;
+        let index_f :f64 = rl_list.len().to_f64().unwrap() / 100.0 * rate;
+        let index = index_f.round().to_usize().unwrap();
+
+        let order_amount = rl_list.get(index).unwrap().order_amount;
+        let order_dict :OrderDict = OrderDict{
+            order_amount,
+            buy_price: truncate_decimal_places(rl_list.get(index).unwrap().bid, self.price_decimal_places),
+            sell_price: truncate_decimal_places(rl_list.get(index).unwrap().ask, self.price_decimal_places)
+        };
+
+        let now_time = OffsetDateTime::now_utc().unix_timestamp();
+        // 检测交易间隔,发起交易
+        if self.order_info_list.len() > 0 || self.last_buy_time + self.buy_time_limit > now_time {
+            return Ok(0);
+        }
+        // 下单
+        let buy_order_id = exchange.place_okx_order(&self.symbol, &"buy".to_string(), &"limit".to_string(), &order_dict.buy_price.to_string(), &order_dict.order_amount.to_string()).await.unwrap();
+        let order = OrderInfo{
+            id: buy_order_id,
+            sell_price: order_dict.sell_price,
+            time_num: now_time
+        };
+        eprintln!("buy_order: {:?}", order);
+        self.order_info_list.push(order);
+        self.last_buy_time = now_time;
+
+        return Ok(0);
+    }
+
+    fn depth_handler(&self, depth: Depth) {
+        println!("{:?}", depth)
+
+        // let (spread, mid_price, ask, bid) = get_spread(depth);
+        // self.mid_price = mid_price;
+        // self.ask = ask;
+        // self.bid = bid;
+        // if self.spread_list.len() > self.spread_list_limit {
+        //     self.spread_list.remove(0);
+        // }
+        // self.spread_list.push(spread);
+    }
+}
 
 #[tokio::main]
 async fn main() {
-
     let spread_list:Vec<f64> = Vec::new();
     // 币对-获取深度信息参数
     let symbol:String = "BTC_USDT".to_string();
@@ -323,10 +297,71 @@ async fn main() {
     let cancel_time_limit = 30;
     // 价格小数位数
     let price_decimal_places = 2;
-    // spread_list_limit 长度限制
+    // spread_list 长度限制
     let spread_list_limit = 10;
 
-    let mut bot = Bot::new(spread_list, symbol, limit, short_interval, rl_start, rl_end, quantity_max, amount_decimal_places, order_info_list, last_buy_time, buy_time_limit, cancel_time_limit, price_decimal_places, spread_list_limit);
-    bot.start().await;
 
-}
+    // 使用std::env::var函数获取环境变量的值
+    let okx_access_key= env::var("okx_access_key").unwrap();
+    let okx_secret_key= env::var("okx_secret_key").unwrap();
+    let okx_passphrase= env::var("okx_passphrase").unwrap();
+    let exchange:Exchange = Exchange::new(okx_access_key, okx_secret_key, okx_passphrase);
+
+    let bot = Bot::new(spread_list, symbol.clone(), limit, short_interval, rl_start, rl_end, quantity_max, amount_decimal_places, order_info_list, last_buy_time, buy_time_limit, cancel_time_limit, price_decimal_places, spread_list_limit);
+    let bot_arc = Arc::new(Mutex::new(bot));
+    let bot_binance_depth_arc = Arc::clone(&bot_arc);
+
+    // bot主线程
+    let bot_arc_thread = tokio::spawn(async move {
+        let okx_access_key= env::var("okx_access_key").unwrap();
+        let okx_secret_key= env::var("okx_secret_key").unwrap();
+        let okx_passphrase= env::var("okx_passphrase").unwrap();
+
+        let exchange:Exchange = Exchange::new(okx_access_key, okx_secret_key, okx_passphrase);
+
+        loop {
+            {
+                let mut bot = bot_arc.lock().await;
+
+                match bot.do_logic(&exchange).await {
+                    Ok(m) => m,
+                    Err(error) => {
+                        eprintln!("异常出现,捕获异常: {}", error);
+                        1
+                    },
+                };
+            }
+
+            tokio::time::sleep(Duration::from_millis(100)).await;
+        }
+
+        // loop {
+        //     println!("我是bot1111111111111");
+        //
+        //     tokio::time::sleep(Duration::from_millis(100)).await;
+        // }
+    });
+
+    // 各种订阅信息,辅助线程
+    let subscribe_binance_depth_thread = tokio::spawn(async move {
+        exchange.get_binance_depth(&symbol, 10, bot_binance_depth_arc).await;
+        // loop {
+        //     {
+        //         let bot = bot_binance_depth_arc.lock().await;
+        //         let callback = bot.depth_handler();
+        //
+        //         exchange.get_binance_depth(&symbol, 10, callback).await;
+        //     }
+        //
+        //     // tokio::time::sleep(Duration::from_millis(100)).await;
+        // }
+
+        // loop {
+        //     println!("我是sub222222222222222方法");
+        //
+        //     tokio::time::sleep(Duration::from_millis(1)).await;
+        // }
+    });
+
+    tokio::try_join!(bot_arc_thread, subscribe_binance_depth_thread).unwrap();
+}