Przeglądaj źródła

配置文件化

skyfffire 2 tygodni temu
rodzic
commit
c9b4cbd1b9

+ 2 - 1
.gitignore

@@ -1,3 +1,4 @@
 /target
 /logs
-Cargo.lock
+Cargo.lock
+config.toml

+ 7 - 18
Cargo.toml

@@ -13,10 +13,6 @@ tokio = { version = "1", features = ["full"] }
 # - "tokio-native-tls": 与 tokio 集成,支持 HTTPS (SSL/TLS),使用系统的 TLS 实现
 reqwest = { version = "0.11", features = ["json", "tokio-native-tls"] }
 
-
-ring = "0.16.20"
-base64 = "0.13"
-
 futures-channel = "0.3.28"
 
 # WebSocket 客户端,基于 tokio 构建,用于订阅 K 线和深度
@@ -26,9 +22,6 @@ tokio-tungstenite= { git = "https://github.com/skyfffire/tokio-tungstenite-proxy
 # 包含了 Stream 的一些方法,例如 split 用于分离 WebSocket stream 的读写端
 futures-util = "0.3"
 
-# URL 解析和构建库,在处理交易所端点或配置代理时可能会用到
-url = "2"
-
 # 序列化和反序列化框架,用于将 Rust 结构体与各种数据格式相互转换
 # - "derive": 启用 derive 宏,方便自动实现 Serialize 和 Deserialize trait
 serde = { version = "1", features = ["derive"] }
@@ -36,6 +29,9 @@ serde = { version = "1", features = ["derive"] }
 # serde 的 JSON 实现,用于处理交易所 API 的 JSON 数据和应用配置(如果使用 JSON 格式)
 serde_json = "1"
 
+# 读取toml配置
+toml = "0.8"
+
 # 日志和诊断框架,推荐使用 tracing,功能强大且适合异步应用
 tracing = "0.1"
 
@@ -52,24 +48,19 @@ tracing-appender = "0.2"
 chrono = "0.4"
 # 时区数据库,用于获取 Asia/Shanghai 时区
 chrono-tz = "0.8"
-# 确保 time crate 的特性被启用
-time = { version = "0.3", features = ["macros", "formatting", "parsing", "local-offset"] }
 
 # 简化错误处理的库,方便快速构建可链式调用的错误
 anyhow = "1"
 
-# 用于定义自定义错误类型的库,与 anyhow 配合使用,让错误更具语义
-thiserror = "1"
-
 # 异步 SQL 数据库客户端,选择支持 tokio 和 SQLite 的版本,用于存储配置等
 # - "runtime-tokio": 使用 tokio 作为异步运行时
 # - "sqlite": 启用 SQLite 驱动
 # - "macros": 启用宏支持,用于 compile-time query checking (强烈推荐)
 # - "offline": 用于离线模式下的宏检查 (与 "macros" 配合使用)
-sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "macros"] }
+# sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "macros"] }
 backtrace = "0.3.74"
 
-prost = "0.11"
+# prost = "0.11"
 hex = "0.4.3"
 
 ##代替f64避免精度丢失
@@ -82,13 +73,11 @@ sha2 = "0.10.8"
 starknet = { git = "https://github.com/xJonathanLEI/starknet-rs", tag = "starknet/v0.17.0" }
 starknet-crypto = "0.8.1"  # `starknet` crate doesn't re-export `PoseidonHasher`
 
-# 随机数发生工具
-rand = "0.8"
-uuid = { version = "1.0", features = ["v4", "fast-rng", "macro-diagnostics"] }
+# 确保 time crate 的特性被启用
+time = { version = "0.3", features = ["macros", "formatting", "parsing", "local-offset"] }
 
 # =======================================================
 # 以下是一些在开发过程中可能会用到的devDependencies,只用于开发和测试,不包含在最终发布版本中
 [dev-dependencies]
 
 [build-dependencies]
-prost-build = "0.11" # 或者最新版本

+ 18 - 0
config.example.toml

@@ -0,0 +1,18 @@
+# ========================================
+# 配置文件示例
+# 复制此文件为 config.toml 并填入真实配置
+# ========================================
+
+[strategy]
+order_quantity = "0.001"                            # 下单数量
+market = "BTC-USD"                                  # 交易对
+stop_loss_ratio = 0.02                              # 0.02代表2%,吊灯止损
+
+[network]
+is_testnet = false                                  # 启用测试网
+
+[account]
+api_key = "YOUR_API_KEY"
+stark_public_key = "0xYOUR_STARK_PUBLIC_KEY"
+stark_private_key = "0xYOUR_STARK_PRIVATE_KEY"
+vault_number = 0

+ 2 - 1
src/exchange/extended_account.rs

@@ -1,6 +1,7 @@
 use rust_decimal::prelude::Zero;
+use serde::Deserialize;
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Deserialize)]
 #[allow(dead_code)]
 pub struct ExtendedAccount {
     pub api_key: String,

+ 10 - 7
src/exchange/extended_rest_client.rs

@@ -10,7 +10,7 @@ use serde_json::{json, Value};
 use starknet::core::types::Felt;
 use tracing::{warn};
 use crate::exchange::extended_account::ExtendedAccount;
-use crate::utils::lib::{get_order_hash, sign_message};
+use crate::utils::starknet_lib::{get_order_hash, sign_message};
 use crate::utils::response::Response;
 use crate::utils::rest_utils::RestUtils;
 
@@ -532,18 +532,21 @@ mod tests {
     use tracing::{info, warn};
     use crate::exchange::extended_account::ExtendedAccount;
     use crate::exchange::extended_rest_client::ExtendedRestClient;
+    use crate::utils::config::Config;
     use crate::utils::log_setup::setup_logging;
 
     async fn get_client() -> ExtendedRestClient {
+        let config = Config::load().unwrap();
+        
         let tag = "Extended";
-        let market = "BTC-USD";
+        let market = config.strategy.market.as_str();
         let account = ExtendedAccount::new(
-            "9ae4030902ab469a1bae8a90464e2e91",
-            "0x71e16e49b717b851ced8347cf0dfa8f490bfb826323b9af624a66285dc99672",
-            "0x47cdde8952945c13460f9129644eade096100810fba59de05452b34aacecff6",
-            220844,
+            config.account.api_key.as_str(),
+            config.account.stark_public_key.as_str(),
+            config.account.stark_private_key.as_str(),
+            config.account.vault_number,
         );
-        let is_testnet = false;
+        let is_testnet = config.network.is_testnet;
 
         let client_result = ExtendedRestClient::new(tag, Some(account), market, is_testnet).await;
 

+ 23 - 11
src/main.rs

@@ -17,10 +17,12 @@ use crate::exchange::extended_account::ExtendedAccount;
 use crate::exchange::extended_rest_client::ExtendedRestClient;
 use crate::exchange::extended_stream_client::ExtendedStreamClient;
 use crate::strategy::Strategy;
+use crate::utils::config::Config;
 use crate::utils::response::Response;
 
 #[tokio::main]
 async fn main() {
+    // 日志初始化
     let _guards = log_setup::setup_logging().unwrap();
 
     // 主进程控制
@@ -44,10 +46,19 @@ async fn main() {
         panic_running.store(false, Ordering::Relaxed);
     }));
 
+    // 配置文件
+    let config_result = Config::load();
+    let config = match config_result {
+        Ok(config) => config,
+        Err(error) => {
+            panic!("Configuration error: {}", error);
+        }
+    };
+
     // ---- 优雅停机处理 (示例: SIGINT/Ctrl+C) ----
     //注意:Windows上可能不支持所有信号,SIGINT通常可用
     let r = running.clone(); // 克隆 Arc 用于 SIGHUP/SIGTERM/SIGINT 处理
-    tokio::spawn(async move {
+    spawn(async move {
         tokio::signal::ctrl_c().await.expect("设置 Ctrl+C 处理器失败");
         warn!("接收到退出信号 (Ctrl+C)... 开始关闭.");
         r.store(false, Ordering::Relaxed);
@@ -56,10 +67,11 @@ async fn main() {
     // ---- 运行核心订阅逻辑 ----
     info!("==================================== 应用程序启动 =======================================");
     let task_running = running.clone();
+    let config_clone = config.clone();
     // 启动一个后台任务来执行订阅和数据处理
-    let subscription_handle = tokio::spawn(async move {
+    let subscription_handle = spawn(async move {
         // 运行获取交易对和订阅 K 线的函数
-        if let Err(e) = run_extended_subscriptions(task_running.clone()).await {
+        if let Err(e) = run_extended_subscriptions(task_running.clone(), &config_clone).await {
             error!("运行 Extended 订阅任务失败: {:?}", e);
             task_running.store(false, Ordering::Relaxed); // 如果启动失败,也设置停止标志
         }
@@ -93,15 +105,15 @@ async fn main() {
 /// * `running` - 用于控制程序是否继续运行的原子布尔值 (Arc 包裹)
 ///
 /// # Returns
-pub async fn run_extended_subscriptions(running: Arc<AtomicBool>) -> Result<()> {
-    let market = "BTC-USD";
+pub async fn run_extended_subscriptions(running: Arc<AtomicBool>, config: &Config) -> Result<()> {
+    let market = config.strategy.market.as_str();
     let account = ExtendedAccount::new(
-        "9ae4030902ab469a1bae8a90464e2e91",
-        "0x71e16e49b717b851ced8347cf0dfa8f490bfb826323b9af624a66285dc99672",
-        "0x47cdde8952945c13460f9129644eade096100810fba59de05452b34aacecff6",
-        220844,
+        config.account.api_key.as_str(),
+        config.account.stark_public_key.as_str(),
+        config.account.stark_private_key.as_str(),
+        config.account.vault_number,
     );
-    let is_testnet = false;
+    let is_testnet = config.network.is_testnet;
 
     // 订阅数据的客户端
     let stream_client_list = vec![
@@ -118,7 +130,7 @@ pub async fn run_extended_subscriptions(running: Arc<AtomicBool>) -> Result<()>
     let data_manager_am = Arc::new(Mutex::new(data_manager));
 
     // 策略执行
-    let strategy = Strategy::new(rest_client_am.clone());
+    let strategy = Strategy::new(rest_client_am.clone(), &config.strategy);
     let strategy_am = Arc::new(Mutex::new(strategy));
 
     // 异步去订阅、并阻塞

+ 3 - 3
src/strategy.rs

@@ -3,12 +3,12 @@ use std::sync::Arc;
 use anyhow::{anyhow, bail, Result};
 use rust_decimal::Decimal;
 use std::time::{Duration, Instant};
-use rust_decimal_macros::dec;
 use tokio::sync::Mutex;
 use tokio::time::sleep;
 use tracing::{info, warn};
 use crate::data_manager::DataManager;
 use crate::exchange::extended_rest_client::ExtendedRestClient;
+use crate::utils::config::StrategyConfig;
 use crate::utils::response::Response;
 
 #[derive(Debug, Clone, PartialEq)]
@@ -45,10 +45,10 @@ pub struct Strategy {
 }
 
 impl Strategy {
-    pub fn new(client_am: Arc<Mutex<ExtendedRestClient>>) -> Strategy {
+    pub fn new(client_am: Arc<Mutex<ExtendedRestClient>>, config: &StrategyConfig) -> Strategy {
         Strategy {
             state: StrategyState::Init,
-            order_quantity: dec!(0.0001),
+            order_quantity: config.order_quantity,
             filled_quantity: Decimal::ZERO,
 
             rest_client: client_am,

+ 156 - 0
src/utils/config.rs

@@ -0,0 +1,156 @@
+use anyhow::{Context, Result};
+use rust_decimal::Decimal;
+use serde::Deserialize;
+use std::{fs, str::FromStr};
+use crate::exchange::extended_account::ExtendedAccount;
+
+/// 完整配置结构
+#[derive(Debug, Deserialize, Clone)]
+pub struct Config {
+    pub strategy: StrategyConfig,
+    pub network: NetworkConfig,
+    pub account: ExtendedAccount,
+}
+
+/// 策略配置
+#[derive(Debug, Deserialize, Clone)]
+pub struct StrategyConfig {
+    /// 下单数量
+    #[serde(deserialize_with = "deserialize_decimal")]
+    pub order_quantity: Decimal,
+
+    /// 交易市场
+    pub market: String,
+
+    /// 止损比例
+    #[serde(deserialize_with = "deserialize_decimal")]
+    pub stop_loss_ratio: Decimal,
+}
+
+/// 网络配置
+#[derive(Debug, Deserialize, Clone)]
+pub struct NetworkConfig {
+    /// 是否使用测试网
+    pub is_testnet: bool,
+}
+
+impl Config {
+    /// 从 config.toml 加载配置
+    pub fn load() -> Result<Self> {
+        Self::load_from_file("config.toml")
+    }
+
+    /// 从指定文件加载配置
+    pub fn load_from_file(path: &str) -> Result<Self> {
+        let content = fs::read_to_string(path)
+            .with_context(|| format!("无法读取配置文件: {}", path))?;
+
+        let config: Config = toml::from_str(&content)
+            .with_context(|| format!("无法解析配置文件: {}", path))?;
+
+        // 验证配置
+        config.validate()?;
+
+        Ok(config)
+    }
+
+    /// 验证配置有效性
+    fn validate(&self) -> Result<()> {
+        // 验证下单数量
+        if self.strategy.order_quantity <= Decimal::ZERO {
+            anyhow::bail!("order_quantity 必须大于 0");
+        }
+
+        // 验证止损比例
+        if self.strategy.stop_loss_ratio <= Decimal::ZERO
+            || self.strategy.stop_loss_ratio >= Decimal::ONE {
+            anyhow::bail!("stop_loss_ratio 必须在 0 到 1 之间");
+        }
+
+        // 验证市场名称
+        if self.strategy.market.is_empty() {
+            anyhow::bail!("market 不能为空");
+        }
+
+        // 验证账户配置
+        if self.account.api_key.is_empty()
+            || self.account.api_key == "YOUR_API_KEY" {
+            anyhow::bail!("请在 config.toml 中配置真实的 api_key");
+        }
+
+        if self.account.stark_private_key.is_empty()
+            || self.account.stark_private_key == "YOUR_STARK_PRIVATE_KEY" {
+            anyhow::bail!("请在 config.toml 中配置真实的 stark_private_key");
+        }
+
+        Ok(())
+    }
+
+    /// 获取网络类型描述
+    pub fn get_network_name(&self) -> &str {
+        if self.network.is_testnet {
+            "Testnet"
+        } else {
+            "Production"
+        }
+    }
+}
+
+/// 自定义 Decimal 反序列化
+fn deserialize_decimal<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    let s = String::deserialize(deserializer)?;
+    Decimal::from_str(&s).map_err(serde::de::Error::custom)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_load_config() {
+        let config = Config::load().expect("加载配置失败");
+
+        println!("配置加载成功!");
+        println!("=====================================");
+        println!("策略配置:");
+        println!("  市场: {}", config.strategy.market);
+        println!("  下单数量: {}", config.strategy.order_quantity);
+        println!("  止损比例: {}%", config.strategy.stop_loss_ratio * Decimal::from(100));
+        println!();
+        println!("网络配置:");
+        println!("  环境: {}", config.get_network_name());
+        println!();
+        println!("账户配置:");
+        println!("  API Key: {}...", &config.account.api_key[..20.min(config.account.api_key.len())]);
+        println!("  Vault Number: {}", config.account.vault_number);
+        println!("=====================================");
+    }
+
+    #[test]
+    fn test_validation() {
+        // 测试无效配置
+        let invalid_toml = r#"
+            [strategy]
+            order_quantity = "-0.001"
+            market = "BTC-USD"
+            stop_loss_ratio = "0.02"
+
+            [network]
+            is_testnet = true
+            testnet_url = "https://test.com"
+            production_url = "https://prod.com"
+
+            [account]
+            api_key = "test"
+            stark_public_key = "test"
+            stark_private_key = "test"
+            vault_number = 1
+        "#;
+
+        let config: Config = toml::from_str(invalid_toml).unwrap();
+        assert!(config.validate().is_err());
+    }
+}

+ 2 - 1
src/utils/mod.rs

@@ -4,4 +4,5 @@ pub mod stream_utils;
 pub(crate) mod response;
 pub(crate) mod proxy;
 pub mod starknet_messages;
-pub mod lib;
+pub mod starknet_lib;
+pub mod config;

+ 0 - 0
src/utils/lib.rs → src/utils/starknet_lib.rs