data_providers.rs 17 KB


  1. use std::{
  2. collections::BTreeMap,
  3. fmt::{self, Write},
  4. hash::Hash,
  5. };
  6. use ordered_float::OrderedFloat;
  7. use rust_decimal::{
  8. prelude::{FromPrimitive, ToPrimitive},
  9. Decimal,
  10. };
  11. use serde::{Deserialize, Deserializer, Serialize};
  12. use serde_json::Value;
  13. pub mod binance;
  14. pub mod bybit;
  15. pub mod fetcher;
/// Connection state of a websocket stream.
///
/// The `Connected` payload has the same type that `setup_websocket_connection`
/// returns (a `FragmentCollector` over the upgraded connection).
// The `Connected` payload is much larger than the unit `Disconnected` variant, so clippy's
// `large_enum_variant` lint is silenced instead of boxing it.
// NOTE(review): presumably a deliberate trade-off — confirm.
#[allow(clippy::large_enum_variant)]
pub enum State {
    /// No open websocket.
    Disconnected,
    /// An open, upgraded websocket.
    Connected(FragmentCollector<TokioIo<Upgraded>>),
}
/// Events reported by the data-provider streams (the producers live outside this file).
#[derive(Debug, Clone)]
pub enum Event {
    /// The stream connected; carries a [`Connection`] marker.
    Connected(Connection),
    /// The stream disconnected. The string presumably describes the reason — confirm at the producers.
    Disconnected(String),
    /// Ticker, an `i64` (presumably the event timestamp — TODO confirm), the current order book,
    /// and the trades delivered with this update.
    DepthReceived(Ticker, i64, Depth, Vec<Trade>),
    /// A candle for `Ticker` on the given `Timeframe`.
    KlineReceived(Ticker, Kline, Timeframe),
}
/// Unit marker carried by [`Event::Connected`].
#[derive(Debug, Clone)]
pub struct Connection;
  30. #[allow(dead_code)]
  31. #[derive(thiserror::Error, Debug)]
  32. pub enum StreamError {
  33. #[error("Fetchrror: {0}")]
  34. FetchError(#[from] reqwest::Error),
  35. #[error("Parsing error: {0}")]
  36. ParseError(String),
  37. #[error("Stream error: {0}")]
  38. WebsocketError(String),
  39. #[error("Invalid request: {0}")]
  40. InvalidRequest(String),
  41. #[error("{0}")]
  42. UnknownError(String),
  43. }
/// Per-ticker metadata.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)]
pub struct TickerInfo {
    /// Minimum price increment; (de)serialized under the `tickSize` key.
    #[serde(rename = "tickSize")]
    pub tick_size: f32,
    /// Market category this ticker belongs to.
    pub market_type: MarketType,
}
/// Identifies which data stream is requested.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum StreamType {
    /// Candles for `ticker` on `exchange` at the given `timeframe`.
    Kline {
        exchange: Exchange,
        ticker: Ticker,
        timeframe: Timeframe,
    },
    /// Order-book depth together with trades for `ticker` on `exchange`.
    DepthAndTrades {
        exchange: Exchange,
        ticker: Ticker,
    },
    /// No stream.
    None,
}
  63. // data types
/// One order-book price level.
///
/// Deserialized from an array of numeric strings (`[price, qty, ...]`) by the custom
/// `Deserialize` impl for this type.
#[derive(Debug, Clone, Copy, Default)]
struct Order {
    price: f32,
    qty: f32,
}
  69. impl<'de> Deserialize<'de> for Order {
  70. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  71. where
  72. D: Deserializer<'de>,
  73. {
  74. let arr: Vec<&str> = Vec::<&str>::deserialize(deserializer)?;
  75. let price: f32 = arr[0].parse::<f32>().map_err(serde::de::Error::custom)?;
  76. let qty: f32 = arr[1].parse::<f32>().map_err(serde::de::Error::custom)?;
  77. Ok(Order { price, qty })
  78. }
  79. }
/// Order-book snapshot: price → quantity for each side.
///
/// Prices are wrapped in `OrderedFloat` so that `f32` can be used as a `BTreeMap` key.
#[derive(Debug, Clone, Default)]
pub struct Depth {
    pub bids: BTreeMap<OrderedFloat<f32>, f32>,
    pub asks: BTreeMap<OrderedFloat<f32>, f32>,
}
/// Depth payload as flat lists of price levels.
///
/// It is either a full snapshot (applied via `LocalDepthCache::fetched`) or an incremental
/// update (applied via `LocalDepthCache::update_depth_cache`).
#[derive(Debug, Clone, Default)]
struct VecLocalDepthCache {
    // Update id carried by this payload; presumably the exchange's sequence id — confirm.
    last_update_id: i64,
    // Payload timestamp; units not visible here (presumably ms) — confirm.
    time: i64,
    bids: Vec<Order>,
    asks: Vec<Order>,
}
/// Locally maintained order book.
///
/// Built from a snapshot and then kept current by applying incremental updates.
#[derive(Debug, Clone, Default)]
struct LocalDepthCache {
    // Id of the most recently applied snapshot or update.
    last_update_id: i64,
    // Timestamp of the most recently applied snapshot or update.
    time: i64,
    bids: BTreeMap<OrderedFloat<f32>, f32>,
    asks: BTreeMap<OrderedFloat<f32>, f32>,
}
  99. impl LocalDepthCache {
  100. fn new() -> Self {
  101. Self {
  102. last_update_id: 0,
  103. time: 0,
  104. bids: BTreeMap::new(),
  105. asks: BTreeMap::new(),
  106. }
  107. }
  108. fn fetched(&mut self, new_depth: &VecLocalDepthCache) {
  109. self.last_update_id = new_depth.last_update_id;
  110. self.time = new_depth.time;
  111. self.bids.clear();
  112. new_depth.bids.iter().for_each(|order| {
  113. self.bids.insert(OrderedFloat(order.price), order.qty);
  114. });
  115. self.asks.clear();
  116. new_depth.asks.iter().for_each(|order| {
  117. self.asks.insert(OrderedFloat(order.price), order.qty);
  118. });
  119. }
  120. fn update_depth_cache(&mut self, new_depth: &VecLocalDepthCache) {
  121. self.last_update_id = new_depth.last_update_id;
  122. self.time = new_depth.time;
  123. new_depth.bids.iter().for_each(|order| {
  124. if order.qty == 0.0 {
  125. self.bids.remove((&order.price).into());
  126. } else {
  127. self.bids.insert(OrderedFloat(order.price), order.qty);
  128. }
  129. });
  130. new_depth.asks.iter().for_each(|order| {
  131. if order.qty == 0.0 {
  132. self.asks.remove((&order.price).into());
  133. } else {
  134. self.asks.insert(OrderedFloat(order.price), order.qty);
  135. }
  136. });
  137. }
  138. fn get_depth(&self) -> Depth {
  139. Depth {
  140. bids: self.bids.clone(),
  141. asks: self.asks.clone(),
  142. }
  143. }
  144. fn get_fetch_id(&self) -> i64 {
  145. self.last_update_id
  146. }
  147. }
/// A single executed trade.
#[derive(Default, Debug, Clone, Copy, Deserialize)]
pub struct Trade {
    // Trade timestamp; presumably milliseconds since epoch — confirm against the adapters.
    pub time: i64,
    /// Side flag; deserialized from an integer `0`/`1` by `bool_from_int` (`1` → `true`).
    #[serde(deserialize_with = "bool_from_int")]
    pub is_sell: bool,
    pub price: f32,
    pub qty: f32,
}
/// One OHLC candle.
#[derive(Default, Debug, Clone, Copy)]
pub struct Kline {
    // Candle timestamp; presumably the open time in ms — confirm.
    pub time: u64,
    pub open: f32,
    pub high: f32,
    pub low: f32,
    pub close: f32,
    // Two volume components; presumably (buy, sell) volume — confirm against the adapters.
    pub volume: (f32, f32),
}
/// Summary statistics for a ticker.
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub struct TickerStats {
    pub mark_price: f32,
    // Daily price change; whether absolute or percent is not visible here — confirm.
    pub daily_price_chg: f32,
    pub daily_volume: f32,
}
/// User-selected multiplier applied to a ticker's minimum tick size (displayed as e.g. `5x`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
pub struct TickMultiplier(pub u16);
  173. impl std::fmt::Display for TickMultiplier {
  174. fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
  175. write!(f, "{}x", self.0)
  176. }
  177. }
  178. impl TickMultiplier {
  179. pub const ALL: [TickMultiplier; 8] = [
  180. TickMultiplier(1),
  181. TickMultiplier(2),
  182. TickMultiplier(5),
  183. TickMultiplier(10),
  184. TickMultiplier(25),
  185. TickMultiplier(50),
  186. TickMultiplier(100),
  187. TickMultiplier(200),
  188. ];
  189. /// Returns the final tick size after applying the user selected multiplier
  190. ///
  191. /// Usually used for price steps in chart scales
  192. pub fn multiply_with_min_tick_size(&self, ticker_info: TickerInfo) -> f32 {
  193. let min_tick_size = ticker_info.tick_size;
  194. let multiplier = if let Some(m) = Decimal::from_f32(f32::from(self.0)) {
  195. m
  196. } else {
  197. log::error!("Failed to convert multiplier: {}", self.0);
  198. return f32::from(self.0) * min_tick_size;
  199. };
  200. let decimal_min_tick_size = if let Some(d) = Decimal::from_f32(min_tick_size) {
  201. d
  202. } else {
  203. log::error!("Failed to convert min_tick_size: {}", min_tick_size);
  204. return f32::from(self.0) * min_tick_size;
  205. };
  206. let normalized = multiplier * decimal_min_tick_size.normalize();
  207. if let Some(tick_size) = normalized.to_f32() {
  208. let decimal_places = calculate_decimal_places(min_tick_size);
  209. round_to_decimal_places(tick_size, decimal_places)
  210. } else {
  211. log::error!("Failed to calculate tick size for multiplier: {}", self.0);
  212. f32::from(self.0) * min_tick_size
  213. }
  214. }
  215. }
  216. // ticksize rounding helpers
  217. fn calculate_decimal_places(value: f32) -> u32 {
  218. let s = value.to_string();
  219. if let Some(decimal_pos) = s.find('.') {
  220. (s.len() - decimal_pos - 1) as u32
  221. } else {
  222. 0
  223. }
  224. }
  225. fn round_to_decimal_places(value: f32, places: u32) -> f32 {
  226. let factor = 10.0f32.powi(places as i32);
  227. (value * factor).round() / factor
  228. }
  229. // connection types
/// Supported exchange venues, split by market (see `Exchange::MARKET_TYPES`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum Exchange {
    BinanceFutures,
    BinanceSpot,
    BybitLinear,
    BybitSpot,
}
  237. impl std::fmt::Display for Exchange {
  238. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  239. write!(
  240. f,
  241. "{}",
  242. match self {
  243. Exchange::BinanceFutures => "Binance Futures",
  244. Exchange::BinanceSpot => "Binance Spot",
  245. Exchange::BybitLinear => "Bybit Linear",
  246. Exchange::BybitSpot => "Bybit Spot",
  247. }
  248. )
  249. }
  250. }
impl Exchange {
    /// Pairs every [`Exchange`] with the [`MarketType`] it serves.
    pub const MARKET_TYPES: [(Exchange, MarketType); 4] = [
        (Exchange::BinanceFutures, MarketType::LinearPerps),
        (Exchange::BybitLinear, MarketType::LinearPerps),
        (Exchange::BinanceSpot, MarketType::Spot),
        (Exchange::BybitSpot, MarketType::Spot),
    ];
}
/// Market category: spot or linear perpetuals.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum MarketType {
    Spot,
    LinearPerps,
}
/// Compact, `Copy` ticker symbol tagged with its market type.
///
/// Up to 20 characters are packed 6 bits each, ten per `u64` word (see `Ticker::new`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub struct Ticker {
    // Packed 6-bit character codes; unused bits stay zero.
    data: [u64; 2],
    // Number of characters stored.
    len: u8,
    market_type: MarketType,
}
  270. impl Default for Ticker {
  271. fn default() -> Self {
  272. Ticker::new("", MarketType::Spot)
  273. }
  274. }
  275. impl Ticker {
  276. pub fn new<S: AsRef<str>>(ticker: S, market_type: MarketType) -> Self {
  277. let ticker = ticker.as_ref();
  278. let base_len = ticker.len();
  279. assert!(base_len <= 20, "Ticker too long");
  280. assert!(
  281. ticker.chars().all(|c| c.is_ascii_alphanumeric()),
  282. "Invalid character in ticker: {ticker:?}"
  283. );
  284. let mut data = [0u64; 2];
  285. let mut len = 0;
  286. for (i, c) in ticker.bytes().enumerate() {
  287. let value = match c {
  288. b'0'..=b'9' => c - b'0',
  289. b'A'..=b'Z' => c - b'A' + 10,
  290. _ => unreachable!(),
  291. };
  292. let shift = (i % 10) * 6;
  293. data[i / 10] |= u64::from(value) << shift;
  294. len += 1;
  295. }
  296. Ticker { data, len, market_type }
  297. }
  298. pub fn get_string(&self) -> (String, MarketType) {
  299. let mut result = String::with_capacity(self.len as usize);
  300. for i in 0..self.len {
  301. let value = (self.data[i as usize / 10] >> ((i % 10) * 6)) & 0x3F;
  302. let c = match value {
  303. 0..=9 => (b'0' + value as u8) as char,
  304. 10..=35 => (b'A' + (value as u8 - 10)) as char,
  305. _ => unreachable!(),
  306. };
  307. result.push(c);
  308. }
  309. (result, self.market_type)
  310. }
  311. }
  312. impl fmt::Display for Ticker {
  313. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  314. // Direct formatting without intermediate String allocation
  315. for i in 0..self.len {
  316. let value = (self.data[i as usize / 10] >> ((i % 10) * 6)) & 0x3F;
  317. let c = match value {
  318. 0..=9 => (b'0' + value as u8) as char,
  319. 10..=35 => (b'A' + (value as u8 - 10)) as char,
  320. _ => unreachable!(),
  321. };
  322. f.write_char(c)?;
  323. }
  324. Ok(())
  325. }
  326. }
  327. impl From<(String, MarketType)> for Ticker {
  328. fn from((s, market_type): (String, MarketType)) -> Self {
  329. Ticker::new(s, market_type)
  330. }
  331. }
  332. impl From<(&str, MarketType)> for Ticker {
  333. fn from((s, market_type): (&str, MarketType)) -> Self {
  334. Ticker::new(s, market_type)
  335. }
  336. }
  337. impl std::fmt::Display for Timeframe {
  338. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  339. write!(
  340. f,
  341. "{}",
  342. match self {
  343. Timeframe::M1 => "1m",
  344. Timeframe::M3 => "3m",
  345. Timeframe::M5 => "5m",
  346. Timeframe::M15 => "15m",
  347. Timeframe::M30 => "30m",
  348. Timeframe::H1 => "1h",
  349. Timeframe::H2 => "2h",
  350. Timeframe::H4 => "4h",
  351. }
  352. )
  353. }
  354. }
/// Candle interval; `M*` variants are minutes, `H*` variants are hours.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum Timeframe {
    M1,
    M3,
    M5,
    M15,
    M30,
    H1,
    H2,
    H4,
}
  366. impl Timeframe {
  367. pub const ALL: [Timeframe; 8] = [
  368. Timeframe::M1,
  369. Timeframe::M3,
  370. Timeframe::M5,
  371. Timeframe::M15,
  372. Timeframe::M30,
  373. Timeframe::H1,
  374. Timeframe::H2,
  375. Timeframe::H4,
  376. ];
  377. pub fn to_minutes(self) -> u16 {
  378. match self {
  379. Timeframe::M1 => 1,
  380. Timeframe::M3 => 3,
  381. Timeframe::M5 => 5,
  382. Timeframe::M15 => 15,
  383. Timeframe::M30 => 30,
  384. Timeframe::H1 => 60,
  385. Timeframe::H2 => 120,
  386. Timeframe::H4 => 240,
  387. }
  388. }
  389. pub fn to_milliseconds(self) -> u64 {
  390. u64::from(self.to_minutes()) * 60_000
  391. }
  392. }
  393. fn bool_from_int<'de, D>(deserializer: D) -> Result<bool, D::Error>
  394. where
  395. D: Deserializer<'de>,
  396. {
  397. let value = Value::deserialize(deserializer)?;
  398. match value.as_i64() {
  399. Some(0) => Ok(false),
  400. Some(1) => Ok(true),
  401. _ => Err(serde::de::Error::custom("expected 0 or 1")),
  402. }
  403. }
  404. fn deserialize_string_to_f32<'de, D>(deserializer: D) -> Result<f32, D::Error>
  405. where
  406. D: serde::Deserializer<'de>,
  407. {
  408. let s: String = serde::Deserialize::deserialize(deserializer)?;
  409. s.parse::<f32>().map_err(serde::de::Error::custom)
  410. }
  411. fn deserialize_string_to_i64<'de, D>(deserializer: D) -> Result<i64, D::Error>
  412. where
  413. D: serde::Deserializer<'de>,
  414. {
  415. let s: String = serde::Deserialize::deserialize(deserializer)?;
  416. s.parse::<i64>().map_err(serde::de::Error::custom)
  417. }
/// One open-interest sample.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OpenInterest {
    // Sample timestamp; presumably ms since epoch — confirm against the adapters.
    pub time: i64,
    pub value: f32,
}
  423. // other helpers
  424. pub fn format_with_commas(num: f32) -> String {
  425. let s = format!("{num:.0}");
  426. // Handle special case for small numbers
  427. if s.len() <= 4 && s.starts_with('-') {
  428. return s; // Return as-is if it's a small negative number
  429. }
  430. let mut result = String::with_capacity(s.len() + (s.len() - 1) / 3);
  431. let (sign, digits) = if s.starts_with('-') {
  432. ("-", &s[1..]) // Split into sign and digits
  433. } else {
  434. ("", &s[..])
  435. };
  436. let mut i = digits.len();
  437. while i > 0 {
  438. if !result.is_empty() {
  439. result.insert(0, ',');
  440. }
  441. let start = if i >= 3 { i - 3 } else { 0 };
  442. result.insert_str(0, &digits[start..i]);
  443. i = start;
  444. }
  445. // Add sign at the start if negative
  446. if !sign.is_empty() {
  447. result.insert_str(0, sign);
  448. }
  449. result
  450. }
  451. // websocket
  452. use bytes::Bytes;
  453. use tokio::net::TcpStream;
  454. use http_body_util::Empty;
  455. use hyper_util::rt::TokioIo;
  456. use fastwebsockets::FragmentCollector;
  457. use hyper::{
  458. header::{CONNECTION, UPGRADE},
  459. upgrade::Upgraded,
  460. Request,
  461. };
  462. use tokio_rustls::{
  463. rustls::{ClientConfig, OwnedTrustAnchor},
  464. TlsConnector,
  465. };
  466. struct SpawnExecutor;
  467. impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
  468. where
  469. Fut: std::future::Future + Send + 'static,
  470. Fut::Output: Send + 'static,
  471. {
  472. fn execute(&self, fut: Fut) {
  473. tokio::task::spawn(fut);
  474. }
  475. }
  476. pub fn tls_connector() -> Result<TlsConnector, StreamError> {
  477. let mut root_store = tokio_rustls::rustls::RootCertStore::empty();
  478. root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
  479. OwnedTrustAnchor::from_subject_spki_name_constraints(
  480. ta.subject,
  481. ta.spki,
  482. ta.name_constraints,
  483. )
  484. }));
  485. let config = ClientConfig::builder()
  486. .with_safe_defaults()
  487. .with_root_certificates(root_store)
  488. .with_no_client_auth();
  489. Ok(TlsConnector::from(std::sync::Arc::new(config)))
  490. }
  491. async fn setup_tcp_connection(domain: &str) -> Result<TcpStream, StreamError> {
  492. let addr = format!("{domain}:443");
  493. TcpStream::connect(&addr)
  494. .await
  495. .map_err(|e| StreamError::WebsocketError(e.to_string()))
  496. }
  497. async fn setup_tls_connection(
  498. domain: &str,
  499. tcp_stream: TcpStream,
  500. ) -> Result<tokio_rustls::client::TlsStream<TcpStream>, StreamError> {
  501. let tls_connector: TlsConnector = tls_connector()?;
  502. let domain: tokio_rustls::rustls::ServerName =
  503. tokio_rustls::rustls::ServerName::try_from(domain)
  504. .map_err(|_| StreamError::ParseError("invalid dnsname".to_string()))?;
  505. tls_connector
  506. .connect(domain, tcp_stream)
  507. .await
  508. .map_err(|e| StreamError::WebsocketError(e.to_string()))
  509. }
  510. async fn setup_websocket_connection(
  511. domain: &str,
  512. tls_stream: tokio_rustls::client::TlsStream<TcpStream>,
  513. url: &str,
  514. ) -> Result<FragmentCollector<TokioIo<Upgraded>>, StreamError> {
  515. let req: Request<Empty<Bytes>> = Request::builder()
  516. .method("GET")
  517. .uri(url)
  518. .header("Host", domain)
  519. .header(UPGRADE, "websocket")
  520. .header(CONNECTION, "upgrade")
  521. .header(
  522. "Sec-WebSocket-Key",
  523. fastwebsockets::handshake::generate_key(),
  524. )
  525. .header("Sec-WebSocket-Version", "13")
  526. .body(Empty::<Bytes>::new())
  527. .map_err(|e| StreamError::WebsocketError(e.to_string()))?;
  528. let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, tls_stream)
  529. .await
  530. .map_err(|e| StreamError::WebsocketError(e.to_string()))?;
  531. Ok(FragmentCollector::new(ws))
  532. }
  533. #[allow(unused_imports)]
  534. mod tests {
  535. use super::*;
  536. #[tokio::test]
  537. async fn fetch_bybit_tickers_with_rate_limits() -> Result<(), StreamError> {
  538. let url = "https://api.bybit.com/v5/market/tickers?category=spot".to_string();
  539. let response = reqwest::get(&url).await.map_err(StreamError::FetchError)?;
  540. println!("{:?}", response.headers());
  541. let _text = response.text().await.map_err(StreamError::FetchError)?;
  542. Ok(())
  543. }
  544. #[tokio::test]
  545. async fn fetch_binance_tickers_with_rate_limits() -> Result<(), StreamError> {
  546. let url = "https://fapi.binance.com/fapi/v1/ticker/24hr".to_string();
  547. let response = reqwest::get(&url).await.map_err(StreamError::FetchError)?;
  548. println!("{:?}", response.headers());
  549. let _text = response.text().await.map_err(StreamError::FetchError)?;
  550. Ok(())
  551. }
  552. }