extended_stream_client.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. use std::sync::Arc;
  2. use std::sync::atomic::AtomicBool;
  3. use std::time::Duration;
  4. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  5. use serde_json::json;
  6. use serde_json::Value;
  7. use tokio::sync::Mutex;
  8. use tokio_tungstenite::tungstenite::{http, Message};
  9. use tracing::{error, info, trace, warn};
  10. use anyhow::Result;
  11. use chrono::Utc;
  12. use tokio_tungstenite::tungstenite::handshake::client::{generate_key, Request};
  13. use tracing_subscriber::fmt::format::json;
  14. use crate::exchange::extended_account::ExtendedAccount;
  15. use crate::utils::response::Response;
  16. use crate::utils::stream_utils::{StreamUtils, HeartbeatType};
  17. #[derive(Clone)]
  18. #[allow(dead_code)]
  19. pub struct ExtendedStreamClient {
  20. // 标签
  21. tag: String,
  22. // 地址
  23. address_url: String,
  24. // 账号
  25. account_option: Option<ExtendedAccount>,
  26. // 心跳间隔
  27. heartbeat_time: u64,
  28. }
  29. impl ExtendedStreamClient {
  30. // ============================================= 构造函数 ================================================
  31. fn new(tag: String, account_option: Option<ExtendedAccount>, subscribe_pattern: String) -> ExtendedStreamClient {
  32. let host = "wss://api.starknet.extended.exchange/stream.extended.exchange/v1/".to_string(); // mainnet
  33. // let host = "wss://api.starknet.sepolia.extended.exchange/stream.extended.exchange/v1/".to_string(); // testnet
  34. let address_url = format!("{}{}", host, subscribe_pattern);
  35. ExtendedStreamClient {
  36. tag,
  37. address_url,
  38. account_option,
  39. heartbeat_time: 1000 * 10,
  40. }
  41. }
  42. // ============================================= 订阅函数 ================================================
  43. pub fn order_books(tag: String, account_option: Option<ExtendedAccount>, symbol: String) -> ExtendedStreamClient {
  44. Self::new(tag, account_option, format!("orderbooks/{}", symbol))
  45. }
  46. // 链接
  47. pub async fn ws_connect_async<F, Future>(&mut self,
  48. is_shutdown_arc: Arc<AtomicBool>,
  49. handle_function: F,
  50. write_tx_am: &Arc<Mutex<UnboundedSender<Message>>>,
  51. write_to_socket_rx: UnboundedReceiver<Message>) -> Result<()>
  52. where
  53. F: Fn(Response) -> Future + Clone + Send + 'static + Sync,
  54. Future: std::future::Future<Output=()> + Send + 'static, // 确保 Fut 是一个 Future,且输出类型为 ()
  55. {
  56. let address_url = self.address_url.clone();
  57. let tag = self.tag.clone();
  58. // 自动心跳包
  59. let write_tx_clone1 = write_tx_am.clone();
  60. let heartbeat_time = self.heartbeat_time.clone();
  61. tokio::spawn(async move {
  62. let ping_obj = json!({"method":"PING"});
  63. StreamUtils::ping_pong(write_tx_clone1, HeartbeatType::Custom(ping_obj.to_string()), heartbeat_time).await;
  64. });
  65. if self.account_option.is_some() {
  66. // 登录相关
  67. }
  68. // 提取host
  69. let parsed_uri: http::Uri = address_url.parse()?;
  70. let host_domain = parsed_uri.host().ok_or("URI 缺少主机名").unwrap().to_string();
  71. let host_header_value = if let Some(port) = parsed_uri.port_u16() {
  72. // 如果端口不是默认的 80 (for ws) 或 443 (for wss),则需要包含端口
  73. // 这里只是简单地判断,更严谨的判断可以根据 scheme 来
  74. match parsed_uri.scheme_str() {
  75. Some("ws") if port == 80 => host_domain.to_string(),
  76. Some("wss") if port == 443 => host_domain.to_string(),
  77. _ => format!("{}:{}", host_domain, port), // 否则包含端口
  78. }
  79. } else {
  80. host_domain.to_string() // 没有端口或使用默认端口
  81. };
  82. // 链接
  83. let t2 = tokio::spawn(async move {
  84. let write_to_socket_rx_arc = Arc::new(Mutex::new(write_to_socket_rx));
  85. loop {
  86. // 通过构建request的方式进行ws链接,可以携带header
  87. let request = Request::builder()
  88. .method("GET")
  89. .uri(&address_url)
  90. .header("Sec-WebSocket-Key", generate_key())
  91. .header("Sec-WebSocket-Version", "13")
  92. .header("Host", host_header_value.clone())
  93. .header("User-Agent", "RustClient/1.0")
  94. .header("Upgrade", "websocket")
  95. .header("Connection", "Upgrade")
  96. .body(())
  97. .unwrap();
  98. trace!("Extended_usdt_swap socket 连接中……");
  99. StreamUtils::ws_connect_async(is_shutdown_arc.clone(), handle_function.clone(), request,
  100. false, tag.clone(), vec![], write_to_socket_rx_arc.clone(),
  101. Self::message_text, Self::message_ping, Self::message_pong, Self::message_binary).await;
  102. warn!("Extended_usdt_swap socket 断连,1s以后重连……");
  103. tokio::time::sleep(Duration::from_secs(1)).await;
  104. }
  105. });
  106. tokio::try_join!(t2)?;
  107. trace!("线程-心跳与链接-结束");
  108. Ok(())
  109. }
  110. //数据解析-Text
  111. pub fn message_text(text: String) -> Option<Response> {
  112. let mut res_data = Response::new("".to_string(), -201, "success".to_string(), Value::Null);
  113. let json_value: Value = serde_json::from_str(&text).unwrap();
  114. // info!("等待解析:{}", serde_json::to_string_pretty(&json_value).unwrap());
  115. match json_value["ts"].as_i64() {
  116. Some(ts) => {
  117. res_data.reach_time = ts;
  118. res_data.received_time = Utc::now().timestamp_millis();
  119. res_data.code = 200;
  120. res_data.data = json_value.clone();
  121. }
  122. None => {
  123. res_data.data = json_value.clone();
  124. res_data.code = -1;
  125. res_data.message = text;
  126. }
  127. }
  128. Option::from(res_data)
  129. }
  130. //数据解析-ping
  131. pub fn message_ping(_pi: Vec<u8>) -> Option<Response> {
  132. Option::from(Response::new("".to_string(), -300, "success".to_string(), Value::Null))
  133. }
  134. //数据解析-pong
  135. pub fn message_pong(_po: Vec<u8>) -> Option<Response> {
  136. Option::from(Response::new("".to_string(), -301, "success".to_string(), Value::Null))
  137. }
  138. //数据解析-二进制
  139. pub fn message_binary(po: Vec<u8>) -> Option<Response> {
  140. // info!("Received binary message ({} bytes)", po.len());
  141. // 1. 尝试用新的顶层消息结构 PublicSpotKlineV3ApiMessage 来解析 K 线数据
  142. // 根据 Topic 前缀判断依然有效,但现在是判断是否**可能**是 K 线相关消息
  143. let prefix_len = po.len().min(100);
  144. let prefix_string = String::from_utf8_lossy(&po[..prefix_len]);
  145. // if prefix_string.contains("spot@public.kline.v3.api.pb") {
  146. // // info!("通过 Topic 前缀判断为 K 线数据相关消息");
  147. //
  148. // // 尝试解析为 PublicSpotKlineV3ApiMessage
  149. // match PublicSpotKlineV3ApiMessage::decode(&po[..]) {
  150. // Ok(kline_message) => {
  151. // // info!("成功解析为顶层 K 线消息结构");
  152. // // 检查是否包含嵌套的 KlineDataV3 字段 (Tag 308)
  153. // if let Some(kline_data) = kline_message.kline_data { // 注意这里 PublicSpotKlineV3ApiMessage 的 kline_data 字段是 Option<KlineDataV3>
  154. // // info!("找到并成功访问嵌套的 KlineDataV3");
  155. // // 现在 kline_data 是 KlineDataV3 结构体,你可以使用它了!
  156. // // 填充 Response 并返回 (省略详细实现)
  157. // let response_data = Response::new(
  158. // kline_message.topic_info.clone(), // 使用解析到的 Topic 信息
  159. // 200,
  160. // "success".to_string(),
  161. // json!({
  162. // "interval": kline_data.interval,
  163. // "windowStart": kline_data.window_start, //注意 snake_case
  164. // "openingPrice": kline_data.opening_price,
  165. // "closingPrice": kline_data.closing_price,
  166. // "highestPrice": kline_data.highest_price,
  167. // "lowestPrice": kline_data.lowest_price,
  168. // "volume": kline_data.volume,
  169. // "amount": kline_data.amount,
  170. // "windowEnd": kline_data.window_end,
  171. // // 可以添加顶层字段的信息,如果需要
  172. // "topic_info": kline_message.topic_info,
  173. // "symbol": kline_message.symbol,
  174. // "id_info": kline_message.id_info,
  175. // "timestamp": kline_message.timestamp,
  176. // })
  177. // );
  178. // return Some(response_data);
  179. // } else {
  180. // info!("顶层 K 线消息结构解析成功,但未找到嵌套的 kline_data 字段 (Tag 308)");
  181. // // 这可能是一个只有顶层字段的控制消息
  182. // return Some(Response::new(
  183. // kline_message.topic_info.clone(), // 使用解析到的 Topic 信息
  184. // 200,
  185. // "OK (Control Message)".to_string(),
  186. // json!({
  187. // "topic_info": kline_message.topic_info,
  188. // "symbol": kline_message.symbol,
  189. // "id_info": kline_message.id_info,
  190. // "timestamp": kline_message.timestamp,
  191. // })
  192. // ));
  193. // }
  194. // }
  195. // Err(e) => {
  196. // error!("尝试解析为 PublicSpotKlineV3ApiMessage 失败: {:?}", e);
  197. // }
  198. // }
  199. // }
  200. //
  201. // // 2. 尝试解析深度数据 (使用新的结构体)
  202. // if prefix_string.contains("spot@public.aggre.depth.v3.api.pb") {
  203. // // info!("通过 Topic 前缀判断为深度数据");
  204. //
  205. // // 尝试解析为 PublicIncreaseDepthsV3ApiMessage (新的顶层深度消息)
  206. // match PublicIncreaseDepthsV3ApiMessage::decode(&po[..]) {
  207. // Ok(depth_message) => {
  208. // // info!("成功解析为顶层深度消息结构");
  209. //
  210. // // 检查是否包含嵌套的 depth_data 字段 (Tag 313)
  211. // if let Some(depth_data_content) = depth_message.depth_data {
  212. // // info!("找到并成功访问嵌套的 DepthDataContentV3");
  213. //
  214. // // 填充 Response 并返回
  215. // let response_data = Response::new(
  216. // depth_message.topic_info.clone(), // 使用解析到的 Topic
  217. // 200,
  218. // "success".to_string(),
  219. // json!({
  220. // // 嵌套消息内部的字段
  221. // "asks": depth_data_content.asks.into_iter().map(|item| json!({"price": item.price, "quantity": item.quantity})).collect::<Vec<_>>(),
  222. // "bids": depth_data_content.bids.into_iter().map(|item| json!({"price": item.price, "quantity": item.quantity})).collect::<Vec<_>>(),
  223. // "eventType": depth_data_content.event_type,
  224. // "version": depth_data_content.version,
  225. // "lastUpdateId": depth_data_content.last_update_id, // 新增字段
  226. //
  227. // // 顶层字段
  228. // "topic_info": depth_message.topic_info,
  229. // "symbol": depth_message.symbol,
  230. // "timestamp": depth_message.timestamp, // 新增字段
  231. // })
  232. // );
  233. // return Some(response_data);
  234. // } else {
  235. // info!("顶层深度消息结构解析成功,但未找到嵌套的 depth_data 字段 (Tag 313)");
  236. // // 处理只有顶层字段的深度相关消息
  237. // return Some(Response::new(
  238. // depth_message.topic_info.clone(),
  239. // 200, "OK (Control Message)".to_string(),
  240. // serde_json::json!({
  241. // "topic_info": depth_message.topic_info,
  242. // "symbol": depth_message.symbol,
  243. // "timestamp": depth_message.timestamp,
  244. // })
  245. // ));
  246. // }
  247. // }
  248. // Err(e) => {
  249. // error!("解析深度消息 PublicIncreaseDepthsV3ApiMessage 失败: {:?}", e);
  250. // }
  251. // }
  252. // }
  253. // 如果都不是已知的 Protobuf 类型,处理未知消息
  254. error!("无法将二进制消息解析为任何已知 Protobuf 类型, {}", prefix_string);
  255. Some(Response::new("".to_string(), 400, "无法解析未知二进制消息".to_string(), Value::Null))
  256. }
  257. }
  258. #[cfg(test)]
  259. mod tests {
  260. use std::sync::Arc;
  261. use std::sync::atomic::AtomicBool;
  262. use tokio::sync::Mutex;
  263. use tokio_tungstenite::tungstenite::Message;
  264. use tracing::info;
  265. use crate::exchange::extended_stream_client::{ExtendedStreamClient};
  266. use crate::utils::response::Response;
  267. use crate::utils::log_setup::setup_logging;
  268. #[tokio::test]
  269. async fn test_extended_ws() {
  270. let ws_running = Arc::new(AtomicBool::new(true));
  271. let (write_tx, write_rx) = futures_channel::mpsc::unbounded::<Message>();
  272. let _guard = setup_logging().unwrap();
  273. let mut ws = ExtendedStreamClient::order_books("Extended".to_string(), None, "BTC-USD".to_string());
  274. let fun = move |response: Response| {
  275. info!("{}", serde_json::to_string_pretty(&response.data).unwrap());
  276. async move {}
  277. };
  278. // 链接
  279. info!("开始链接");
  280. let write_tx_am = Arc::new(Mutex::new(write_tx));
  281. ws.ws_connect_async(ws_running, fun, &write_tx_am, write_rx)
  282. .await
  283. .expect("链接失败");
  284. }
  285. }