|
|
@@ -0,0 +1,156 @@
|
|
|
+use anyhow::{Context, Result};
|
|
|
+use rust_decimal::Decimal;
|
|
|
+use serde::Deserialize;
|
|
|
+use std::{fs, str::FromStr};
|
|
|
+use crate::exchange::extended_account::ExtendedAccount;
|
|
|
+
|
|
|
+/// 完整配置结构
|
|
|
+#[derive(Debug, Deserialize, Clone)]
|
|
|
+pub struct Config {
|
|
|
+ pub strategy: StrategyConfig,
|
|
|
+ pub network: NetworkConfig,
|
|
|
+ pub account: ExtendedAccount,
|
|
|
+}
|
|
|
+
|
|
|
+/// 策略配置
|
|
|
+#[derive(Debug, Deserialize, Clone)]
|
|
|
+pub struct StrategyConfig {
|
|
|
+ /// 下单数量
|
|
|
+ #[serde(deserialize_with = "deserialize_decimal")]
|
|
|
+ pub order_quantity: Decimal,
|
|
|
+
|
|
|
+ /// 交易市场
|
|
|
+ pub market: String,
|
|
|
+
|
|
|
+ /// 止损比例
|
|
|
+ #[serde(deserialize_with = "deserialize_decimal")]
|
|
|
+ pub stop_loss_ratio: Decimal,
|
|
|
+}
|
|
|
+
|
|
|
+/// 网络配置
|
|
|
+#[derive(Debug, Deserialize, Clone)]
|
|
|
+pub struct NetworkConfig {
|
|
|
+ /// 是否使用测试网
|
|
|
+ pub is_testnet: bool,
|
|
|
+}
|
|
|
+
|
|
|
+impl Config {
|
|
|
+ /// 从 config.toml 加载配置
|
|
|
+ pub fn load() -> Result<Self> {
|
|
|
+ Self::load_from_file("config.toml")
|
|
|
+ }
|
|
|
+
|
|
|
+ /// 从指定文件加载配置
|
|
|
+ pub fn load_from_file(path: &str) -> Result<Self> {
|
|
|
+ let content = fs::read_to_string(path)
|
|
|
+ .with_context(|| format!("无法读取配置文件: {}", path))?;
|
|
|
+
|
|
|
+ let config: Config = toml::from_str(&content)
|
|
|
+ .with_context(|| format!("无法解析配置文件: {}", path))?;
|
|
|
+
|
|
|
+ // 验证配置
|
|
|
+ config.validate()?;
|
|
|
+
|
|
|
+ Ok(config)
|
|
|
+ }
|
|
|
+
|
|
|
+ /// 验证配置有效性
|
|
|
+ fn validate(&self) -> Result<()> {
|
|
|
+ // 验证下单数量
|
|
|
+ if self.strategy.order_quantity <= Decimal::ZERO {
|
|
|
+ anyhow::bail!("order_quantity 必须大于 0");
|
|
|
+ }
|
|
|
+
|
|
|
+ // 验证止损比例
|
|
|
+ if self.strategy.stop_loss_ratio <= Decimal::ZERO
|
|
|
+ || self.strategy.stop_loss_ratio >= Decimal::ONE {
|
|
|
+ anyhow::bail!("stop_loss_ratio 必须在 0 到 1 之间");
|
|
|
+ }
|
|
|
+
|
|
|
+ // 验证市场名称
|
|
|
+ if self.strategy.market.is_empty() {
|
|
|
+ anyhow::bail!("market 不能为空");
|
|
|
+ }
|
|
|
+
|
|
|
+ // 验证账户配置
|
|
|
+ if self.account.api_key.is_empty()
|
|
|
+ || self.account.api_key == "YOUR_API_KEY" {
|
|
|
+ anyhow::bail!("请在 config.toml 中配置真实的 api_key");
|
|
|
+ }
|
|
|
+
|
|
|
+ if self.account.stark_private_key.is_empty()
|
|
|
+ || self.account.stark_private_key == "YOUR_STARK_PRIVATE_KEY" {
|
|
|
+ anyhow::bail!("请在 config.toml 中配置真实的 stark_private_key");
|
|
|
+ }
|
|
|
+
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// 获取网络类型描述
|
|
|
+ pub fn get_network_name(&self) -> &str {
|
|
|
+ if self.network.is_testnet {
|
|
|
+ "Testnet"
|
|
|
+ } else {
|
|
|
+ "Production"
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// 自定义 Decimal 反序列化
|
|
|
+fn deserialize_decimal<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
|
|
|
+where
|
|
|
+ D: serde::Deserializer<'de>,
|
|
|
+{
|
|
|
+ let s = String::deserialize(deserializer)?;
|
|
|
+ Decimal::from_str(&s).map_err(serde::de::Error::custom)
|
|
|
+}
|
|
|
+
|
|
|
+#[cfg(test)]
|
|
|
+mod tests {
|
|
|
+ use super::*;
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_load_config() {
|
|
|
+ let config = Config::load().expect("加载配置失败");
|
|
|
+
|
|
|
+ println!("配置加载成功!");
|
|
|
+ println!("=====================================");
|
|
|
+ println!("策略配置:");
|
|
|
+ println!(" 市场: {}", config.strategy.market);
|
|
|
+ println!(" 下单数量: {}", config.strategy.order_quantity);
|
|
|
+ println!(" 止损比例: {}%", config.strategy.stop_loss_ratio * Decimal::from(100));
|
|
|
+ println!();
|
|
|
+ println!("网络配置:");
|
|
|
+ println!(" 环境: {}", config.get_network_name());
|
|
|
+ println!();
|
|
|
+ println!("账户配置:");
|
|
|
+ println!(" API Key: {}...", &config.account.api_key[..20.min(config.account.api_key.len())]);
|
|
|
+ println!(" Vault Number: {}", config.account.vault_number);
|
|
|
+ println!("=====================================");
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn test_validation() {
|
|
|
+ // 测试无效配置
|
|
|
+ let invalid_toml = r#"
|
|
|
+ [strategy]
|
|
|
+ order_quantity = "-0.001"
|
|
|
+ market = "BTC-USD"
|
|
|
+ stop_loss_ratio = "0.02"
|
|
|
+
|
|
|
+ [network]
|
|
|
+ is_testnet = true
|
|
|
+ testnet_url = "https://test.com"
|
|
|
+ production_url = "https://prod.com"
|
|
|
+
|
|
|
+ [account]
|
|
|
+ api_key = "test"
|
|
|
+ stark_public_key = "test"
|
|
|
+ stark_private_key = "test"
|
|
|
+ vault_number = 1
|
|
|
+ "#;
|
|
|
+
|
|
|
+ let config: Config = toml::from_str(invalid_toml).unwrap();
|
|
|
+ assert!(config.validate().is_err());
|
|
|
+ }
|
|
|
+}
|