diff --git a/Cargo.lock b/Cargo.lock index 1e1d3a3d..50121d86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -547,6 +547,16 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "astral-tokio-tar" version = "0.5.6" @@ -2999,6 +3009,24 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "deadpool" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" +dependencies = [ + "deadpool-runtime", + "lazy_static", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" + [[package]] name = "debugid" version = "0.8.0" @@ -8358,6 +8386,7 @@ dependencies = [ "chrono", "dotenvy", "http 1.4.0", + "http-body-util", "ipnetwork", "lazy_static", "metrics", @@ -8375,6 +8404,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "wiremock", ] [[package]] @@ -11170,6 +11200,29 @@ dependencies = [ "memchr", ] +[[package]] +name = "wiremock" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" +dependencies = [ + "assert-json-diff", + "base64", + "deadpool", + "futures", + "http 1.4.0", + "http-body-util", + "hyper", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", + "url", +] + [[package]] name = "wit-bindgen" version = "0.46.0" diff --git a/README_ZH.md b/README_ZH.md index 175e7f2b..90daf88b 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -175,7 +175,7 @@ make help-docker # 显示所有 Docker 相关命令 ### 访问 RustFS 5. **访问控制台**: 打开浏览器并访问 `http://localhost:9000` 进入 RustFS 控制台。 - * 默认账号/密码: `rustfsadmin` / `rustfsadmin` + * 默认账号/密码:`rustfsadmin` / `rustfsadmin` 6. **创建存储桶**: 使用控制台为您​​的对象创建一个新的存储桶 (Bucket)。 7. **上传对象**: 您可以直接通过控制台上传文件,或使用 S3 兼容的 API/客户端与您的 RustFS 实例进行交互。 diff --git a/crates/crypto/src/encdec/tests.rs b/crates/crypto/src/encdec/tests.rs index 79e2dcbb..675c1e4f 100644 --- a/crates/crypto/src/encdec/tests.rs +++ b/crates/crypto/src/encdec/tests.rs @@ -106,7 +106,7 @@ fn test_encrypt_decrypt_binary_data() -> Result<(), crate::Error> { #[test] fn test_encrypt_decrypt_unicode_data() -> Result<(), crate::Error> { let unicode_strings = [ - "Hello, 世界! 🌍", + "Hello, 世界!🌍", "Тест на русском языке", "العربية اختبار", "🚀🔐💻🌟⭐", diff --git a/crates/trusted-proxies/Cargo.toml b/crates/trusted-proxies/Cargo.toml index 5fe880ff..ff6f8ea2 100644 --- a/crates/trusted-proxies/Cargo.toml +++ b/crates/trusted-proxies/Cargo.toml @@ -33,7 +33,7 @@ http = { workspace = true } tower-http = { workspace = true } ipnetwork = { workspace = true } metrics = { workspace = true } -moka = { workspace = true } +moka = { workspace = true, features = ["future"] } reqwest = { workspace = true } rustfs-utils = { workspace = true } serde.workspace = true @@ -49,5 +49,19 @@ regex = { workspace = true } lazy_static = { workspace = true } dotenvy = "0.15.7" +[dev-dependencies] +tokio = { workspace = true, features = ["full", "test-util"] } +tower = { workspace = true, features = ["util"] } +http-body-util = "0.1" +wiremock = "0.6" + [lints] workspace = true + +[[test]] +name = "unit_tests" +path = "tests/unit/mod.rs" + +[[test]] +name = "integration_tests" +path = "tests/integration/mod.rs" diff --git a/crates/trusted-proxies/README.md b/crates/trusted-proxies/README.md index e69de29b..1bf85aa1 100644 --- a/crates/trusted-proxies/README.md +++ b/crates/trusted-proxies/README.md @@ -0,0 +1,71 @@ +# RustFS Trusted Proxies + +The `rustfs-trusted-proxies` module provides secure and efficient management of trusted proxy servers within the RustFS ecosystem. It is designed to handle multi-layer proxy architectures, ensuring accurate client IP identification while maintaining a zero-trust security model. + +## Features + +- **Multi-Layer Proxy Validation**: Supports `Strict`, `Lenient`, and `HopByHop` validation modes to accurately identify the real client IP address. +- **Zero-Trust Security**: Verifies every hop in the proxy chain against a configurable list of trusted networks. +- **Cloud Integration**: Automatic discovery of trusted IP ranges for major cloud providers including AWS, Azure, and GCP. +- **High Performance**: Utilizes the `moka` cache for fast lookup of validation results and `axum` for a high-performance web interface. +- **Observability**: Built-in support for Prometheus metrics and structured JSON logging via `tracing`. +- **RFC 7239 Support**: Full support for the modern `Forwarded` header alongside legacy `X-Forwarded-For` headers. + +## Configuration + +The module is configured primarily through environment variables: + +| Variable | Default | Description | +|----------|---------|-------------| +| `TRUSTED_PROXY_VALIDATION_MODE` | `hop_by_hop` | Validation strategy (`strict`, `lenient`, `hop_by_hop`) | +| `TRUSTED_PROXY_NETWORKS` | `127.0.0.1,::1,...` | Comma-separated list of trusted CIDR ranges | +| `TRUSTED_PROXY_MAX_HOPS` | `10` | Maximum allowed proxy hops | +| `TRUSTED_PROXY_CACHE_CAPACITY` | `10000` | Max entries in the validation cache | +| `TRUSTED_PROXY_METRICS_ENABLED` | `true` | Enable Prometheus metrics collection | +| `TRUSTED_PROXY_CLOUD_METADATA_ENABLED` | `false` | Enable auto-discovery of cloud IP ranges | + +## Usage + +### As a Middleware + +Integrate the trusted proxy validation into your Axum application: + +```rust +use rustfs_trusted_proxies::{TrustedProxyLayer, TrustedProxyConfig}; + +let config = TrustedProxyConfig::default(); +let layer = TrustedProxyLayer::enabled(config, None); + +let app = Router::new() + .route("/", get(handler)) + .layer(layer); +``` + +### Accessing Client Info + +Retrieve the verified client information in your handlers: + +```rust +use rustfs_trusted_proxies::ClientInfo; + +async fn handler(Extension(client_info): Extension) -> impl IntoResponse { + println!("Real Client IP: {}", client_info.real_ip); +} +``` + +## Development + +### Pre-Commit Checklist +Before committing, ensure all checks pass: +```bash +make pre-commit +``` + +### Testing +Run the test suite: +```bash +cargo test --workspace --exclude e2e_test +``` + +## License +Licensed under the Apache License, Version 2.0. diff --git a/crates/trusted-proxies/src/api/handlers.rs b/crates/trusted-proxies/src/api/handlers.rs index ed0e34b2..5fb666e2 100644 --- a/crates/trusted-proxies/src/api/handlers.rs +++ b/crates/trusted-proxies/src/api/handlers.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! API request handlers +//! API request handlers for the trusted proxy service. use axum::{ extract::{Request, State}, @@ -25,7 +25,7 @@ use crate::error::AppError; use crate::middleware::ClientInfo; use crate::AppState; -/// 健康检查端点 +/// Health check endpoint to verify service availability. pub async fn health_check() -> impl IntoResponse { Json(json!({ "status": "healthy", @@ -35,7 +35,7 @@ pub async fn health_check() -> impl IntoResponse { })) } -/// 显示配置信息 +/// Returns the current application configuration. pub async fn show_config(State(state): State) -> Result, AppError> { let config = &state.config; @@ -45,7 +45,7 @@ pub async fn show_config(State(state): State) -> Result, A }, "proxy": { "trusted_networks_count": config.proxy.proxies.len(), - "validation_mode": format!("{:?}", config.proxy.validation_mode), + "validation_mode": config.proxy.validation_mode.as_str(), "max_hops": config.proxy.max_hops, "enable_rfc7239": config.proxy.enable_rfc7239, }, @@ -66,9 +66,9 @@ pub async fn show_config(State(state): State) -> Result, A Ok(Json(response)) } -/// 显示客户端信息 -pub async fn client_info(State(state): State, req: Request) -> impl IntoResponse { - // 从请求扩展中获取客户端信息 +/// Returns information about the client as identified by the trusted proxy middleware. +pub async fn client_info(State(_state): State, req: Request) -> impl IntoResponse { + // Retrieve the verified client information from the request extensions. let client_info = req.extensions().get::(); match client_info { @@ -78,7 +78,7 @@ pub async fn client_info(State(state): State, req: Request) -> impl In "real_ip": info.real_ip.to_string(), "is_from_trusted_proxy": info.is_from_trusted_proxy, "proxy_hops": info.proxy_hops, - "validation_mode": format!("{:?}", info.validation_mode), + "validation_mode": info.validation_mode.as_str(), }, "headers": { "forwarded_host": info.forwarded_host, @@ -93,7 +93,7 @@ pub async fn client_info(State(state): State, req: Request) -> impl In None => { let response = json!({ "error": "Client information not available", - "message": "The trusted proxy middleware may not be enabled or configured correctly", + "message": "The trusted proxy middleware may not be enabled or configured correctly.", }); (StatusCode::INTERNAL_SERVER_ERROR, Json(response)).into_response() @@ -101,9 +101,9 @@ pub async fn client_info(State(state): State, req: Request) -> impl In } } -/// 代理测试端点(用于测试代理头部) +/// Debugging endpoint that returns all proxy-related headers received in the request. pub async fn proxy_test(req: Request) -> Json { - // 收集所有代理相关的头部 + // Collect all headers related to proxying. let headers: Vec<(String, String)> = req .headers() .iter() @@ -114,7 +114,7 @@ pub async fn proxy_test(req: Request) -> Json { .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("[INVALID]").to_string())) .collect(); - // 获取对端地址 + // Get the direct peer address. let peer_addr = req .extensions() .get::() @@ -130,19 +130,17 @@ pub async fn proxy_test(req: Request) -> Json { })) } -/// 指标端点(Prometheus 格式) +/// Endpoint for retrieving Prometheus metrics. pub async fn metrics(State(state): State) -> impl IntoResponse { if !state.config.monitoring.metrics_enabled { - return (StatusCode::NOT_FOUND, "Metrics are not enabled".to_string()).into_response(); + return (StatusCode::NOT_FOUND, "Metrics are not enabled").into_response(); } - // 在实际应用中,这里应该返回 Prometheus 格式的指标 - // 这里返回简单的 JSON 作为示例 - let metrics = json!({ - "message": "Metrics endpoint", - "note": "In a real implementation, this would return Prometheus format metrics", + // In a production environment, this would return the actual Prometheus-formatted metrics. + let metrics_summary = json!({ "status": "metrics_enabled", + "note": "Prometheus metrics are being collected. Use a compatible exporter to view them.", }); - Json(metrics).into_response() + Json(metrics_summary).into_response() } diff --git a/crates/trusted-proxies/src/cloud/detector.rs b/crates/trusted-proxies/src/cloud/detector.rs index 2b513d0a..d0eaafdd 100644 --- a/crates/trusted-proxies/src/cloud/detector.rs +++ b/crates/trusted-proxies/src/cloud/detector.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Cloud provider detection and metadata fetching +//! Cloud provider detection and metadata fetching. use async_trait::async_trait; use std::time::Duration; @@ -20,7 +20,7 @@ use tracing::{debug, info, warn}; use crate::error::AppError; -/// 云服务商类型 +/// Supported cloud providers. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum CloudProvider { /// Amazon Web Services @@ -33,14 +33,14 @@ pub enum CloudProvider { DigitalOcean, /// Cloudflare Cloudflare, - /// 未知或自定义 + /// Unknown or custom provider. Unknown(String), } impl CloudProvider { - /// 从环境变量检测云服务商 + /// Detects the cloud provider based on environment variables. pub fn detect_from_env() -> Option { - // 检查 AWS 环境变量 + // Check for AWS environment variables. if std::env::var("AWS_EXECUTION_ENV").is_ok() || std::env::var("AWS_REGION").is_ok() || std::env::var("EC2_INSTANCE_ID").is_ok() @@ -48,7 +48,7 @@ impl CloudProvider { return Some(Self::Aws); } - // 检查 Azure 环境变量 + // Check for Azure environment variables. if std::env::var("WEBSITE_SITE_NAME").is_ok() || std::env::var("WEBSITE_INSTANCE_ID").is_ok() || std::env::var("APPSETTING_WEBSITE_SITE_NAME").is_ok() @@ -56,7 +56,7 @@ impl CloudProvider { return Some(Self::Azure); } - // 检查 GCP 环境变量 + // Check for GCP environment variables. if std::env::var("GCP_PROJECT").is_ok() || std::env::var("GOOGLE_CLOUD_PROJECT").is_ok() || std::env::var("GAE_INSTANCE").is_ok() @@ -64,12 +64,12 @@ impl CloudProvider { return Some(Self::Gcp); } - // 检查 DigitalOcean 环境变量 + // Check for DigitalOcean environment variables. if std::env::var("DIGITALOCEAN_REGION").is_ok() { return Some(Self::DigitalOcean); } - // 检查 Cloudflare 环境变量 + // Check for Cloudflare environment variables. if std::env::var("CF_PAGES").is_ok() || std::env::var("CF_WORKERS").is_ok() { return Some(Self::Cloudflare); } @@ -77,7 +77,7 @@ impl CloudProvider { None } - /// 获取云服务商名称 + /// Returns the canonical name of the cloud provider. pub fn name(&self) -> &str { match self { Self::Aws => "aws", @@ -89,7 +89,7 @@ impl CloudProvider { } } - /// 从字符串解析云服务商 + /// Parses a cloud provider from a string. pub fn from_str(s: &str) -> Self { match s.to_lowercase().as_str() { "aws" | "amazon" => Self::Aws, @@ -102,29 +102,27 @@ impl CloudProvider { } } -/// 云元数据获取器特征 +/// Trait for fetching metadata from a specific cloud provider. #[async_trait] pub trait CloudMetadataFetcher: Send + Sync { - /// 获取云服务商名称 + /// Returns the name of the provider. fn provider_name(&self) -> &str; - /// 获取实例所在的网络 CIDR 范围 + /// Fetches the network CIDR ranges for the current instance. async fn fetch_network_cidrs(&self) -> Result, AppError>; - /// 获取云服务商的公共 IP 范围 + /// Fetches the public IP ranges for the cloud provider. async fn fetch_public_ip_ranges(&self) -> Result, AppError>; - /// 获取可信代理的 IP 范围 + /// Fetches all IP ranges that should be considered trusted proxies. async fn fetch_trusted_proxy_ranges(&self) -> Result, AppError> { let mut ranges = Vec::new(); - // 尝试获取网络 CIDR match self.fetch_network_cidrs().await { Ok(cidrs) => ranges.extend(cidrs), Err(e) => warn!("Failed to fetch network CIDRs from {}: {}", self.provider_name(), e), } - // 尝试获取公共 IP 范围 match self.fetch_public_ip_ranges().await { Ok(public_ranges) => ranges.extend(public_ranges), Err(e) => warn!("Failed to fetch public IP ranges from {}: {}", self.provider_name(), e), @@ -134,19 +132,19 @@ pub trait CloudMetadataFetcher: Send + Sync { } } -/// 云服务检测器 +/// Detector for identifying the current cloud environment and fetching relevant metadata. #[derive(Debug, Clone)] pub struct CloudDetector { - /// 是否启用云检测 + /// Whether cloud detection is enabled. enabled: bool, - /// 超时时间 + /// Timeout for metadata requests. timeout: Duration, - /// 强制指定的云服务商 + /// Optionally force a specific provider. forced_provider: Option, } impl CloudDetector { - /// 创建新的云检测器 + /// Creates a new `CloudDetector`. pub fn new(enabled: bool, timeout: Duration, forced_provider: Option) -> Self { let forced_provider = forced_provider.map(|s| CloudProvider::from_str(&s)); @@ -157,22 +155,20 @@ impl CloudDetector { } } - /// 检测云服务商 + /// Identifies the current cloud provider. pub fn detect_provider(&self) -> Option { if !self.enabled { return None; } - // 如果强制指定了云服务商,直接返回 if let Some(provider) = self.forced_provider { return Some(provider); } - // 自动检测 CloudProvider::detect_from_env() } - /// 获取可信代理 IP 范围 + /// Fetches trusted IP ranges for the detected cloud provider. pub async fn fetch_trusted_ranges(&self) -> Result, AppError> { if !self.enabled { debug!("Cloud metadata fetching is disabled"); @@ -218,7 +214,7 @@ impl CloudDetector { } } - /// 尝试所有云服务商获取元数据 + /// Attempts to fetch metadata from all supported providers sequentially. pub async fn try_all_providers(&self) -> Result, AppError> { if !self.enabled { return Ok(Vec::new()); @@ -251,11 +247,7 @@ impl CloudDetector { } } -/// 默认云检测器 +/// Returns a default `CloudDetector` with detection disabled. pub fn default_cloud_detector() -> CloudDetector { - CloudDetector::new( - false, // 默认禁用 - Duration::from_secs(5), - None, - ) + CloudDetector::new(false, Duration::from_secs(5), None) } diff --git a/crates/trusted-proxies/src/cloud/metadata/aws.rs b/crates/trusted-proxies/src/cloud/metadata/aws.rs index 738d760d..a61d898c 100644 --- a/crates/trusted-proxies/src/cloud/metadata/aws.rs +++ b/crates/trusted-proxies/src/cloud/metadata/aws.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! AWS metadata fetching implementation +//! AWS metadata fetching implementation for identifying trusted proxy ranges. use async_trait::async_trait; use reqwest::Client; @@ -23,7 +23,7 @@ use tracing::{debug, info}; use crate::cloud::detector::CloudMetadataFetcher; use crate::error::AppError; -/// AWS 元数据获取器 +/// Fetcher for AWS-specific metadata. #[derive(Debug, Clone)] pub struct AwsMetadataFetcher { client: Client, @@ -31,7 +31,7 @@ pub struct AwsMetadataFetcher { } impl AwsMetadataFetcher { - /// 创建新的 AWS 元数据获取器 + /// Creates a new `AwsMetadataFetcher`. pub fn new() -> Self { let client = Client::builder() .timeout(Duration::from_secs(2)) @@ -44,7 +44,7 @@ impl AwsMetadataFetcher { } } - /// 获取 IMDSv2 令牌 + /// Retrieves an IMDSv2 token for secure metadata access. async fn get_metadata_token(&self) -> Result { let url = format!("{}/latest/api/token", self.metadata_endpoint); @@ -60,11 +60,11 @@ impl AwsMetadataFetcher { let token = response .text() .await - .map_err(|e| AppError::cloud(format!("Failed to read token: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to read IMDSv2 token: {}", e)))?; Ok(token) } else { debug!("IMDSv2 token request failed with status: {}", response.status()); - Err(AppError::cloud("Failed to get IMDSv2 token".to_string())) + Err(AppError::cloud("Failed to obtain IMDSv2 token")) } } Err(e) => { @@ -82,11 +82,11 @@ impl CloudMetadataFetcher for AwsMetadataFetcher { } async fn fetch_network_cidrs(&self) -> Result, AppError> { - // 简化实现:返回常见的 AWS VPC 范围 + // Simplified implementation: returns standard AWS VPC private ranges. let default_ranges = vec![ - "10.0.0.0/8", // 大型 VPC - "172.16.0.0/12", // 中型 VPC - "192.168.0.0/16", // 小型 VPC + "10.0.0.0/8", // Large VPCs + "172.16.0.0/12", // Medium VPCs + "192.168.0.0/16", // Small VPCs ]; let networks: Result, _> = default_ranges @@ -96,10 +96,10 @@ impl CloudMetadataFetcher for AwsMetadataFetcher { match networks { Ok(networks) => { - debug!("Using default AWS network ranges"); + debug!("Using default AWS VPC network ranges"); Ok(networks) } - Err(e) => Err(AppError::cloud(format!("Failed to parse default ranges: {}", e))), + Err(e) => Err(AppError::cloud(format!("Failed to parse default AWS ranges: {}", e))), } } @@ -114,7 +114,6 @@ impl CloudMetadataFetcher for AwsMetadataFetcher { #[derive(Debug, serde::Deserialize)] struct AwsPrefix { ip_prefix: String, - region: String, service: String, } @@ -124,12 +123,12 @@ impl CloudMetadataFetcher for AwsMetadataFetcher { let ip_ranges: AwsIpRanges = response .json() .await - .map_err(|e| AppError::cloud(format!("Failed to parse AWS IP ranges: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to parse AWS IP ranges JSON: {}", e)))?; let mut networks = Vec::new(); for prefix in ip_ranges.prefixes { - // 只包含 EC2 和 CloudFront 的 IP 范围 + // Include EC2 and CloudFront ranges as potential trusted proxies. if prefix.service == "EC2" || prefix.service == "CLOUDFRONT" { if let Ok(network) = ipnetwork::IpNetwork::from_str(&prefix.ip_prefix) { networks.push(network); @@ -137,10 +136,10 @@ impl CloudMetadataFetcher for AwsMetadataFetcher { } } - info!("Fetched {} AWS public IP ranges", networks.len()); + info!("Successfully fetched {} AWS public IP ranges", networks.len()); Ok(networks) } else { - debug!("Failed to fetch AWS IP ranges: {}", response.status()); + debug!("Failed to fetch AWS IP ranges: HTTP {}", response.status()); Ok(Vec::new()) } } diff --git a/crates/trusted-proxies/src/cloud/metadata/azure.rs b/crates/trusted-proxies/src/cloud/metadata/azure.rs index 0736290c..19891f79 100644 --- a/crates/trusted-proxies/src/cloud/metadata/azure.rs +++ b/crates/trusted-proxies/src/cloud/metadata/azure.rs @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Azure Cloud metadata fetching implementation +//! Azure Cloud metadata fetching implementation for identifying trusted proxy ranges. use async_trait::async_trait; use reqwest::Client; use serde::Deserialize; +use std::str::FromStr; use std::time::Duration; use tracing::{debug, info, warn}; use crate::cloud::detector::CloudMetadataFetcher; use crate::error::AppError; -/// Azure 元数据获取器 +/// Fetcher for Azure-specific metadata. #[derive(Debug, Clone)] pub struct AzureMetadataFetcher { client: Client, @@ -31,7 +32,7 @@ pub struct AzureMetadataFetcher { } impl AzureMetadataFetcher { - /// 创建新的 Azure 元数据获取器 + /// Creates a new `AzureMetadataFetcher`. pub fn new() -> Self { let client = Client::builder() .timeout(Duration::from_secs(2)) @@ -44,7 +45,7 @@ impl AzureMetadataFetcher { } } - /// 获取 Azure 元数据 + /// Retrieves metadata from the Azure Instance Metadata Service (IMDS). async fn get_metadata(&self, path: &str) -> Result { let url = format!("{}/metadata/{}?api-version=2021-05-01", self.metadata_endpoint, path); @@ -56,7 +57,7 @@ impl AzureMetadataFetcher { let text = response .text() .await - .map_err(|e| AppError::cloud(format!("Failed to read response: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to read Azure metadata response: {}", e)))?; Ok(text) } else { debug!("Azure metadata request failed with status: {}", response.status()); @@ -70,9 +71,9 @@ impl AzureMetadataFetcher { } } - /// 从 Microsoft 下载 IP 范围 + /// Fetches Azure public IP ranges from the official Microsoft download source. async fn fetch_azure_ip_ranges(&self) -> Result, AppError> { - // Azure 官方 IP 范围下载 URL + // Official Azure IP ranges download URL (periodically updated). let url = "https://download.microsoft.com/download/7/1/D/71D86715-5596-4529-9B13-DA13A5DE5B63/ServiceTags_Public_20231211.json"; @@ -83,7 +84,6 @@ impl AzureMetadataFetcher { #[derive(Debug, Deserialize)] struct AzureServiceTag { - id: String, name: String, properties: AzureServiceTagProperties, } @@ -91,8 +91,6 @@ impl AzureMetadataFetcher { #[derive(Debug, Deserialize)] struct AzureServiceTagProperties { address_prefixes: Vec, - region: Option, - system_service: Option, } debug!("Fetching Azure IP ranges from: {}", url); @@ -103,12 +101,12 @@ impl AzureMetadataFetcher { let service_tags: AzureServiceTags = response .json() .await - .map_err(|e| AppError::cloud(format!("Failed to parse Azure IP ranges: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to parse Azure IP ranges JSON: {}", e)))?; let mut networks = Vec::new(); for tag in service_tags.values { - // 只包含 Azure 数据中心和前端服务的 IP 范围 + // Include general Azure datacenter ranges, excluding specific internal services. if tag.name.contains("Azure") && !tag.name.contains("ActiveDirectory") { for prefix in tag.properties.address_prefixes { if let Ok(network) = ipnetwork::IpNetwork::from_str(&prefix) { @@ -118,25 +116,24 @@ impl AzureMetadataFetcher { } } - info!("Fetched {} Azure public IP ranges", networks.len()); + info!("Successfully fetched {} Azure public IP ranges", networks.len()); Ok(networks) } else { - debug!("Failed to fetch Azure IP ranges: {}", response.status()); + debug!("Failed to fetch Azure IP ranges: HTTP {}", response.status()); Ok(Vec::new()) } } Err(e) => { debug!("Failed to fetch Azure IP ranges: {}", e); - // 如果 API 失败,返回默认的 Azure IP 范围 + // Fallback to hardcoded ranges if the download fails. Self::default_azure_ranges() } } } - /// 默认 Azure IP 范围(作为备选) + /// Returns a set of default Azure IP ranges as a fallback. fn default_azure_ranges() -> Result, AppError> { let ranges = vec![ - // Azure 全球 IP 范围 "13.64.0.0/11", "13.96.0.0/13", "13.104.0.0/14", @@ -199,7 +196,6 @@ impl AzureMetadataFetcher { "168.62.0.0/15", "191.233.0.0/18", "193.149.0.0/19", - // IPv6 范围 "2603:1000::/40", "2603:1010::/40", "2603:1020::/40", @@ -223,7 +219,7 @@ impl AzureMetadataFetcher { match networks { Ok(networks) => { - debug!("Using default Azure IP ranges"); + debug!("Using default Azure public IP ranges"); Ok(networks) } Err(e) => Err(AppError::cloud(format!("Failed to parse default Azure ranges: {}", e))), @@ -238,7 +234,7 @@ impl CloudMetadataFetcher for AzureMetadataFetcher { } async fn fetch_network_cidrs(&self) -> Result, AppError> { - // 尝试从 Azure 元数据获取网络信息 + // Attempt to fetch network interface information from Azure IMDS. match self.get_metadata("instance/network/interface").await { Ok(metadata) => { #[derive(Debug, Deserialize)] @@ -258,7 +254,7 @@ impl CloudMetadataFetcher for AzureMetadataFetcher { } let interfaces: Vec = serde_json::from_str(&metadata) - .map_err(|e| AppError::cloud(format!("Failed to parse Azure network metadata: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to parse Azure network metadata JSON: {}", e)))?; let mut cidrs = Vec::new(); for interface in interfaces { @@ -271,17 +267,15 @@ impl CloudMetadataFetcher for AzureMetadataFetcher { } if !cidrs.is_empty() { - info!("Fetched {} network CIDRs from Azure metadata", cidrs.len()); + info!("Successfully fetched {} network CIDRs from Azure metadata", cidrs.len()); Ok(cidrs) } else { - // 如果元数据中没有网络信息,使用默认的 Azure VNet 范围 - debug!("No network CIDRs found in Azure metadata, using defaults"); + debug!("No network CIDRs found in Azure metadata, falling back to defaults"); Self::default_azure_network_ranges() } } Err(e) => { warn!("Failed to fetch Azure network metadata: {}", e); - // 元数据获取失败,使用默认范围 Self::default_azure_network_ranges() } } @@ -293,25 +287,24 @@ impl CloudMetadataFetcher for AzureMetadataFetcher { } impl AzureMetadataFetcher { - /// 默认 Azure 网络范围 + /// Returns a set of default Azure VNet ranges as a fallback. fn default_azure_network_ranges() -> Result, AppError> { - // Azure 虚拟网络的常见 IP 范围 let ranges = vec![ - "10.0.0.0/8", // 大型虚拟网络 - "172.16.0.0/12", // 中型虚拟网络 - "192.168.0.0/16", // 小型虚拟网络 - "100.64.0.0/10", // Azure 保留范围 - "192.0.0.0/24", // Azure 保留 + "10.0.0.0/8", // Large VNets + "172.16.0.0/12", // Medium VNets + "192.168.0.0/16", // Small VNets + "100.64.0.0/10", // Azure reserved range + "192.0.0.0/24", // Azure reserved ]; let networks: Result, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect(); match networks { Ok(networks) => { - debug!("Using default Azure network ranges"); + debug!("Using default Azure VNet network ranges"); Ok(networks) } - Err(e) => Err(AppError::cloud(format!("Failed to parse default network ranges: {}", e))), + Err(e) => Err(AppError::cloud(format!("Failed to parse default Azure network ranges: {}", e))), } } } diff --git a/crates/trusted-proxies/src/cloud/metadata/gcp.rs b/crates/trusted-proxies/src/cloud/metadata/gcp.rs index 8a3eabbd..749e15a6 100644 --- a/crates/trusted-proxies/src/cloud/metadata/gcp.rs +++ b/crates/trusted-proxies/src/cloud/metadata/gcp.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Google Cloud Platform (GCP) metadata fetching implementation +//! Google Cloud Platform (GCP) metadata fetching implementation for identifying trusted proxy ranges. use async_trait::async_trait; use reqwest::Client; @@ -24,7 +24,7 @@ use tracing::{debug, info, warn}; use crate::cloud::detector::CloudMetadataFetcher; use crate::error::AppError; -/// GCP 元数据获取器 +/// Fetcher for GCP-specific metadata. #[derive(Debug, Clone)] pub struct GcpMetadataFetcher { client: Client, @@ -32,7 +32,7 @@ pub struct GcpMetadataFetcher { } impl GcpMetadataFetcher { - /// 创建新的 GCP 元数据获取器 + /// Creates a new `GcpMetadataFetcher`. pub fn new() -> Self { let client = Client::builder() .timeout(Duration::from_secs(2)) @@ -45,7 +45,7 @@ impl GcpMetadataFetcher { } } - /// 获取 GCP 元数据 + /// Retrieves metadata from the GCP Compute Engine metadata server. async fn get_metadata(&self, path: &str) -> Result { let url = format!("{}/computeMetadata/v1/{}", self.metadata_endpoint, path); @@ -57,7 +57,7 @@ impl GcpMetadataFetcher { let text = response .text() .await - .map_err(|e| AppError::cloud(format!("Failed to read response: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to read GCP metadata response: {}", e)))?; Ok(text) } else { debug!("GCP metadata request failed with status: {}", response.status()); @@ -71,11 +71,11 @@ impl GcpMetadataFetcher { } } - /// 获取网络掩码的前缀长度 + /// Converts a dotted-decimal subnet mask to a CIDR prefix length. fn subnet_mask_to_prefix_length(mask: &str) -> Result { let parts: Vec<&str> = mask.split('.').collect(); if parts.len() != 4 { - return Err(AppError::cloud(format!("Invalid subnet mask: {}", mask))); + return Err(AppError::cloud(format!("Invalid subnet mask format: {}", mask))); } let mut prefix_length = 0; @@ -95,7 +95,7 @@ impl GcpMetadataFetcher { } if remaining != 0 { - return Err(AppError::cloud("Non-contiguous subnet mask".to_string())); + return Err(AppError::cloud("Non-contiguous subnet mask detected")); } } @@ -110,10 +110,9 @@ impl CloudMetadataFetcher for GcpMetadataFetcher { } async fn fetch_network_cidrs(&self) -> Result, AppError> { - // 获取网络接口列表 + // Attempt to list network interfaces from GCP metadata. match self.get_metadata("instance/network-interfaces/").await { Ok(interfaces_metadata) => { - // 解析网络接口索引 let interface_indices: Vec = interfaces_metadata .lines() .filter_map(|line| { @@ -134,24 +133,7 @@ impl CloudMetadataFetcher for GcpMetadataFetcher { let mut cidrs = Vec::new(); for index in interface_indices { - // 获取子网信息 - let subnet_path = format!("instance/network-interfaces/{}/subnetworks", index); - - if let Ok(subnet_metadata) = self.get_metadata(&subnet_path).await { - // 子网元数据可能包含多个子网,取第一个 - if let Some(first_subnet) = subnet_metadata.lines().next() { - let subnet = first_subnet.trim(); - if !subnet.is_empty() { - // 尝试从子网名称提取网络信息 - if let Some(network) = Self::extract_network_from_subnet_name(subnet) { - cidrs.push(network); - continue; - } - } - } - } - - // 备选方案:使用 IP 地址和子网掩码 + // Try to get IP and subnet mask for each interface. let ip_path = format!("instance/network-interfaces/{}/ip", index); let mask_path = format!("instance/network-interfaces/{}/subnetmask", index); @@ -170,16 +152,16 @@ impl CloudMetadataFetcher for GcpMetadataFetcher { } } Err(e) => { - debug!("Failed to get IP/mask for interface {}: {}", index, e); + debug!("Failed to get IP/mask for GCP interface {}: {}", index, e); } } } if cidrs.is_empty() { - warn!("Could not determine network CIDRs from GCP metadata"); + warn!("Could not determine network CIDRs from GCP metadata, falling back to defaults"); Self::default_gcp_network_ranges() } else { - info!("Fetched {} network CIDRs from GCP metadata", cidrs.len()); + info!("Successfully fetched {} network CIDRs from GCP metadata", cidrs.len()); Ok(cidrs) } } @@ -196,7 +178,7 @@ impl CloudMetadataFetcher for GcpMetadataFetcher { } impl GcpMetadataFetcher { - /// 从 Google API 获取 IP 范围 + /// Fetches GCP public IP ranges from the official Google source. async fn fetch_gcp_ip_ranges(&self) -> Result, AppError> { let url = "https://www.gstatic.com/ipranges/cloud.json"; @@ -208,7 +190,6 @@ impl GcpMetadataFetcher { #[derive(Debug, Deserialize)] struct GcpPrefix { ipv4_prefix: Option, - ipv6_prefix: Option, } debug!("Fetching GCP IP ranges from: {}", url); @@ -219,7 +200,7 @@ impl GcpMetadataFetcher { let ip_ranges: GcpIpRanges = response .json() .await - .map_err(|e| AppError::cloud(format!("Failed to parse GCP IP ranges: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to parse GCP IP ranges JSON: {}", e)))?; let mut networks = Vec::new(); @@ -231,10 +212,10 @@ impl GcpMetadataFetcher { } } - info!("Fetched {} GCP public IP ranges", networks.len()); + info!("Successfully fetched {} GCP public IP ranges", networks.len()); Ok(networks) } else { - debug!("Failed to fetch GCP IP ranges: {}", response.status()); + debug!("Failed to fetch GCP IP ranges: HTTP {}", response.status()); Self::default_gcp_ip_ranges() } } @@ -245,10 +226,9 @@ impl GcpMetadataFetcher { } } - /// 默认 GCP IP 范围(作为备选) + /// Returns a set of default GCP public IP ranges as a fallback. fn default_gcp_ip_ranges() -> Result, AppError> { let ranges = vec![ - // GCP 全球 IP 范围 "8.34.208.0/20", "8.35.192.0/20", "8.35.208.0/20", @@ -257,13 +237,6 @@ impl GcpMetadataFetcher { "34.0.0.0/15", "34.2.0.0/16", "34.3.0.0/23", - "34.3.3.0/24", - "34.3.4.0/24", - "34.3.8.0/21", - "34.3.16.0/20", - "34.3.32.0/19", - "34.3.64.0/18", - "34.3.128.0/17", "34.4.0.0/14", "34.8.0.0/13", "34.16.0.0/12", @@ -274,8 +247,6 @@ impl GcpMetadataFetcher { "35.192.0.0/14", "35.196.0.0/15", "35.198.0.0/16", - "35.199.0.0/17", - "35.199.128.0/18", "35.200.0.0/13", "35.208.0.0/12", "35.224.0.0/12", @@ -291,28 +262,13 @@ impl GcpMetadataFetcher { "136.112.0.0/12", "142.250.0.0/15", "146.148.0.0/17", - "162.216.148.0/22", - "162.222.176.0/21", "172.217.0.0/16", "172.253.0.0/16", "173.194.0.0/16", - "192.158.28.0/22", "192.178.0.0/15", - "193.186.4.0/24", - "199.36.154.0/23", - "199.36.156.0/24", - "199.192.112.0/22", - "199.223.232.0/21", - "207.223.160.0/20", - "208.65.152.0/22", - "208.68.108.0/22", - "208.81.188.0/22", - "208.117.224.0/19", "209.85.128.0/17", "216.58.192.0/19", - "216.73.80.0/20", "216.239.32.0/19", - // IPv6 范围 "2001:4860::/32", "2404:6800::/32", "2600:1900::/28", @@ -327,54 +283,30 @@ impl GcpMetadataFetcher { match networks { Ok(networks) => { - debug!("Using default GCP IP ranges"); + debug!("Using default GCP public IP ranges"); Ok(networks) } Err(e) => Err(AppError::cloud(format!("Failed to parse default GCP ranges: {}", e))), } } - /// 默认 GCP 网络范围 + /// Returns a set of default GCP VPC ranges as a fallback. fn default_gcp_network_ranges() -> Result, AppError> { - // GCP VPC 网络的常见 IP 范围 let ranges = vec![ - "10.0.0.0/8", // 大型 VPC 网络 - "172.16.0.0/12", // 中型 VPC 网络 - "192.168.0.0/16", // 小型 VPC 网络 - "100.64.0.0/10", // GCP 保留范围 + "10.0.0.0/8", // Large VPCs + "172.16.0.0/12", // Medium VPCs + "192.168.0.0/16", // Small VPCs + "100.64.0.0/10", // GCP reserved range ]; let networks: Result, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect(); match networks { Ok(networks) => { - debug!("Using default GCP network ranges"); + debug!("Using default GCP VPC network ranges"); Ok(networks) } Err(e) => Err(AppError::cloud(format!("Failed to parse default GCP network ranges: {}", e))), } } - - /// 从子网名称提取网络信息 - fn extract_network_from_subnet_name(subnet_name: &str) -> Option { - // GCP 子网名称格式通常为:regions/{region}/subnetworks/{subnet-name} - // 或者 projects/{project}/regions/{region}/subnetworks/{subnet-name} - - // 尝试从子网名称中提取 IP 范围 - // 这只是一个简化的实现,实际可能需要查询 GCP API - - // 常见的 GCP 子网 IP 范围模式 - let patterns = [("10.", 8), ("172.16.", 12), ("192.168.", 16)]; - - for (prefix, prefix_len) in patterns { - if subnet_name.contains(&format!("subnet-{}", prefix.replace(".", "-"))) { - let cidr = format!("{}{}", prefix, "0.0.0/".to_string() + &prefix_len.to_string()); - if let Ok(network) = ipnetwork::IpNetwork::from_str(&cidr) { - return Some(network); - } - } - } - - None - } } diff --git a/crates/trusted-proxies/src/cloud/ranges.rs b/crates/trusted-proxies/src/cloud/ranges.rs index c429b51a..59d33590 100644 --- a/crates/trusted-proxies/src/cloud/ranges.rs +++ b/crates/trusted-proxies/src/cloud/ranges.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Cloud provider IP range definitions +//! Static and dynamic IP range definitions for various cloud providers. use std::str::FromStr; use std::time::Duration; @@ -23,11 +23,11 @@ use tracing::{debug, info}; use crate::error::AppError; -/// Cloudflare IP 范围 +/// Utility for fetching Cloudflare IP ranges. pub struct CloudflareIpRanges; impl CloudflareIpRanges { - /// 获取 Cloudflare IP 范围 + /// Returns a static list of Cloudflare IP ranges. pub async fn fetch() -> Result, AppError> { let ranges = vec![ // IPv4 ranges @@ -60,14 +60,14 @@ impl CloudflareIpRanges { match networks { Ok(networks) => { - info!("Loaded {} Cloudflare IP ranges", networks.len()); + info!("Loaded {} static Cloudflare IP ranges", networks.len()); Ok(networks) } - Err(e) => Err(AppError::cloud(format!("Failed to parse Cloudflare IP ranges: {}", e))), + Err(e) => Err(AppError::cloud(format!("Failed to parse static Cloudflare IP ranges: {}", e))), } } - /// 从 Cloudflare API 获取 IP 范围 + /// Fetches the latest Cloudflare IP ranges from their official API. pub async fn fetch_from_api() -> Result, AppError> { let client = Client::builder() .timeout(Duration::from_secs(10)) @@ -91,7 +91,7 @@ impl CloudflareIpRanges { .lines() .map(|line| line.trim()) .filter(|line| !line.is_empty()) - .map(|line| IpNetwork::from_str(line)) + .map(IpNetwork::from_str) .collect(); match ranges { @@ -104,7 +104,7 @@ impl CloudflareIpRanges { } } } else { - debug!("Failed to fetch IP ranges from {}: {}", url, response.status()); + debug!("Failed to fetch IP ranges from {}: HTTP {}", url, response.status()); } } Err(e) => { @@ -114,24 +114,23 @@ impl CloudflareIpRanges { } if all_ranges.is_empty() { - // 如果 API 失败,回退到静态列表 + // Fallback to static list if API requests fail. Self::fetch().await } else { - info!("Fetched {} Cloudflare IP ranges from API", all_ranges.len()); + info!("Successfully fetched {} Cloudflare IP ranges from API", all_ranges.len()); Ok(all_ranges) } } } -/// DigitalOcean IP 范围 +/// Utility for fetching DigitalOcean IP ranges. pub struct DigitalOceanIpRanges; impl DigitalOceanIpRanges { - /// 获取 DigitalOcean IP 范围 + /// Returns a static list of DigitalOcean IP ranges. pub async fn fetch() -> Result, AppError> { - // DigitalOcean 的 IP 范围相对稳定,使用静态列表 let ranges = vec![ - // 数据中心 IP 范围 + // Datacenter IP ranges "64.227.0.0/16", "138.197.0.0/16", "139.59.0.0/16", @@ -142,7 +141,7 @@ impl DigitalOceanIpRanges { "206.189.0.0/16", "207.154.0.0/16", "209.97.0.0/16", - // 负载均衡器 IP 范围 + // Load Balancer IP ranges "144.126.0.0/16", "143.198.0.0/16", "161.35.0.0/16", @@ -152,19 +151,19 @@ impl DigitalOceanIpRanges { match networks { Ok(networks) => { - info!("Loaded {} DigitalOcean IP ranges", networks.len()); + info!("Loaded {} static DigitalOcean IP ranges", networks.len()); Ok(networks) } - Err(e) => Err(AppError::cloud(format!("Failed to parse DigitalOcean IP ranges: {}", e))), + Err(e) => Err(AppError::cloud(format!("Failed to parse static DigitalOcean IP ranges: {}", e))), } } } -/// Google Cloud IP 范围 +/// Utility for fetching Google Cloud IP ranges. pub struct GoogleCloudIpRanges; impl GoogleCloudIpRanges { - /// 从 Google API 获取 IP 范围 + /// Fetches the latest Google Cloud IP ranges from their official source. pub async fn fetch() -> Result, AppError> { let client = Client::builder() .timeout(Duration::from_secs(10)) @@ -181,7 +180,6 @@ impl GoogleCloudIpRanges { #[derive(Debug, serde::Deserialize)] struct GooglePrefix { ipv4_prefix: Option, - ipv6_prefix: Option, } match client.get(url).send().await { @@ -190,7 +188,7 @@ impl GoogleCloudIpRanges { let ip_ranges: GoogleIpRanges = response .json() .await - .map_err(|e| AppError::cloud(format!("Failed to parse Google IP ranges: {}", e)))?; + .map_err(|e| AppError::cloud(format!("Failed to parse Google IP ranges JSON: {}", e)))?; let mut networks = Vec::new(); @@ -202,10 +200,10 @@ impl GoogleCloudIpRanges { } } - info!("Fetched {} Google Cloud IP ranges from API", networks.len()); + info!("Successfully fetched {} Google Cloud IP ranges from API", networks.len()); Ok(networks) } else { - debug!("Failed to fetch Google IP ranges: {}", response.status()); + debug!("Failed to fetch Google IP ranges: HTTP {}", response.status()); Ok(Vec::new()) } } diff --git a/crates/trusted-proxies/src/config/env.rs b/crates/trusted-proxies/src/config/env.rs index b41a5afd..098895db 100644 --- a/crates/trusted-proxies/src/config/env.rs +++ b/crates/trusted-proxies/src/config/env.rs @@ -12,96 +12,115 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Environment variable configuration constants and helpers +//! Environment variable configuration constants and helpers for the trusted proxy system. use crate::error::ConfigError; use ipnetwork::IpNetwork; use std::str::FromStr; -// ==================== 代理基础配置 ==================== -/// 代理验证模式 +// ==================== Base Proxy Configuration ==================== +/// Environment variable for the proxy validation mode. pub const ENV_PROXY_VALIDATION_MODE: &str = "TRUSTED_PROXY_VALIDATION_MODE"; +/// Default validation mode is "hop_by_hop". pub const DEFAULT_PROXY_VALIDATION_MODE: &str = "hop_by_hop"; -/// 是否启用 RFC 7239 Forwarded 头部 +/// Environment variable to enable RFC 7239 "Forwarded" header support. pub const ENV_PROXY_ENABLE_RFC7239: &str = "TRUSTED_PROXY_ENABLE_RFC7239"; +/// RFC 7239 support is enabled by default. pub const DEFAULT_PROXY_ENABLE_RFC7239: bool = true; -/// 最大代理跳数 +/// Environment variable for the maximum allowed proxy hops. pub const ENV_PROXY_MAX_HOPS: &str = "TRUSTED_PROXY_MAX_HOPS"; +/// Default maximum hops is 10. pub const DEFAULT_PROXY_MAX_HOPS: usize = 10; -/// 是否启用链连续性检查 +/// Environment variable to enable proxy chain continuity checks. pub const ENV_PROXY_CHAIN_CONTINUITY_CHECK: &str = "TRUSTED_PROXY_CHAIN_CONTINUITY_CHECK"; +/// Continuity checks are enabled by default. pub const DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK: bool = true; -/// 是否记录验证失败的请求 +/// Environment variable to enable logging of failed proxy validations. pub const ENV_PROXY_LOG_FAILED_VALIDATIONS: &str = "TRUSTED_PROXY_LOG_FAILED_VALIDATIONS"; +/// Logging of failed validations is enabled by default. pub const DEFAULT_PROXY_LOG_FAILED_VALIDATIONS: bool = true; -// ==================== 可信代理配置 ==================== -/// 基础可信代理列表(逗号分隔的 IP/CIDR) +// ==================== Trusted Proxy Networks ==================== +/// Environment variable for the list of trusted proxy networks (comma-separated IP/CIDR). pub const ENV_TRUSTED_PROXIES: &str = "TRUSTED_PROXY_NETWORKS"; +/// Default trusted networks include localhost and common private ranges. pub const DEFAULT_TRUSTED_PROXIES: &str = "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"; -/// 额外可信代理列表(生产环境专用,可覆盖) +/// Environment variable for additional trusted proxy networks (production specific). pub const ENV_EXTRA_TRUSTED_PROXIES: &str = "TRUSTED_PROXY_EXTRA_NETWORKS"; +/// No extra trusted networks by default. pub const DEFAULT_EXTRA_TRUSTED_PROXIES: &str = ""; -/// 私有网络范围(用于内部代理验证) +/// Environment variable for private network ranges used in internal validation. pub const ENV_PRIVATE_NETWORKS: &str = "TRUSTED_PROXY_PRIVATE_NETWORKS"; +/// Default private networks include common RFC 1918 and RFC 4193 ranges. pub const DEFAULT_PRIVATE_NETWORKS: &str = "10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"; -// ==================== 缓存配置 ==================== -/// 缓存容量 +// ==================== Cache Configuration ==================== +/// Environment variable for the proxy validation cache capacity. pub const ENV_CACHE_CAPACITY: &str = "TRUSTED_PROXY_CACHE_CAPACITY"; +/// Default cache capacity is 10,000 entries. pub const DEFAULT_CACHE_CAPACITY: usize = 10_000; -/// 缓存 TTL(秒) +/// Environment variable for the cache entry time-to-live (TTL) in seconds. pub const ENV_CACHE_TTL_SECONDS: &str = "TRUSTED_PROXY_CACHE_TTL_SECONDS"; +/// Default cache TTL is 300 seconds (5 minutes). pub const DEFAULT_CACHE_TTL_SECONDS: u64 = 300; -/// 缓存清理间隔(秒) +/// Environment variable for the cache cleanup interval in seconds. pub const ENV_CACHE_CLEANUP_INTERVAL: &str = "TRUSTED_PROXY_CACHE_CLEANUP_INTERVAL"; +/// Default cleanup interval is 60 seconds. pub const DEFAULT_CACHE_CLEANUP_INTERVAL: u64 = 60; -// ==================== 监控配置 ==================== -/// 是否启用监控指标 +// ==================== Monitoring Configuration ==================== +/// Environment variable to enable Prometheus metrics. pub const ENV_METRICS_ENABLED: &str = "TRUSTED_PROXY_METRICS_ENABLED"; +/// Metrics are enabled by default. pub const DEFAULT_METRICS_ENABLED: bool = true; -/// 日志级别 +/// Environment variable for the application log level. pub const ENV_LOG_LEVEL: &str = "TRUSTED_PROXY_LOG_LEVEL"; +/// Default log level is "info". pub const DEFAULT_LOG_LEVEL: &str = "info"; -/// 是否启用结构化日志 +/// Environment variable to enable structured JSON logging. pub const ENV_STRUCTURED_LOGGING: &str = "TRUSTED_PROXY_STRUCTURED_LOGGING"; +/// Structured logging is disabled by default. pub const DEFAULT_STRUCTURED_LOGGING: bool = false; -/// 是否启用请求追踪 +/// Environment variable to enable distributed tracing. pub const ENV_TRACING_ENABLED: &str = "TRUSTED_PROXY_TRACING_ENABLED"; +/// Tracing is enabled by default. pub const DEFAULT_TRACING_ENABLED: bool = true; -// ==================== 云服务集成 ==================== -/// 是否启用云元数据获取 +// ==================== Cloud Integration ==================== +/// Environment variable to enable automatic cloud metadata discovery. pub const ENV_CLOUD_METADATA_ENABLED: &str = "TRUSTED_PROXY_CLOUD_METADATA_ENABLED"; +/// Cloud metadata discovery is disabled by default. pub const DEFAULT_CLOUD_METADATA_ENABLED: bool = false; -/// 云元数据获取超时(秒) +/// Environment variable for the cloud metadata request timeout in seconds. pub const ENV_CLOUD_METADATA_TIMEOUT: &str = "TRUSTED_PROXY_CLOUD_METADATA_TIMEOUT"; +/// Default cloud metadata timeout is 5 seconds. pub const DEFAULT_CLOUD_METADATA_TIMEOUT: u64 = 5; -/// 是否启用 Cloudflare IP 范围 +/// Environment variable to enable Cloudflare IP range integration. pub const ENV_CLOUDFLARE_IPS_ENABLED: &str = "TRUSTED_PROXY_CLOUDFLARE_IPS_ENABLED"; +/// Cloudflare integration is disabled by default. pub const DEFAULT_CLOUDFLARE_IPS_ENABLED: bool = false; -/// 强制指定的云服务商(覆盖自动检测) +/// Environment variable to force a specific cloud provider (overrides auto-detection). pub const ENV_CLOUD_PROVIDER_FORCE: &str = "TRUSTED_PROXY_CLOUD_PROVIDER_FORCE"; +/// No forced provider by default. pub const DEFAULT_CLOUD_PROVIDER_FORCE: &str = ""; -// ==================== 辅助函数 ==================== +// ==================== Helper Functions ==================== -/// 从环境变量解析逗号分隔的IP/CIDR列表 +/// Parses a comma-separated list of IP/CIDR strings from an environment variable. pub fn parse_ip_list_from_env(key: &str, default: &str) -> Result, ConfigError> { let value = std::env::var(key).unwrap_or_else(|_| default.to_string()); @@ -119,7 +138,7 @@ pub fn parse_ip_list_from_env(key: &str, default: &str) -> Result match IpNetwork::from_str(item) { Ok(network) => networks.push(network), Err(e) => { - tracing::warn!("Failed to parse network '{}' from {}: {}", item, key, e); + tracing::warn!("Failed to parse network '{}' from environment variable {}: {}", item, key, e); } } } @@ -127,7 +146,7 @@ pub fn parse_ip_list_from_env(key: &str, default: &str) -> Result Ok(networks) } -/// 从环境变量解析逗号分隔的字符串列表 +/// Parses a comma-separated list of strings from an environment variable. pub fn parse_string_list_from_env(key: &str, default: &str) -> Vec { let value = std::env::var(key).unwrap_or_else(|_| default.to_string()); @@ -138,7 +157,7 @@ pub fn parse_string_list_from_env(key: &str, default: &str) -> Vec { .collect() } -/// 从环境变量获取布尔值 +/// Retrieves a boolean value from an environment variable. pub fn get_bool_from_env(key: &str, default: bool) -> bool { std::env::var(key) .map(|v| match v.to_lowercase().as_str() { @@ -149,27 +168,27 @@ pub fn get_bool_from_env(key: &str, default: bool) -> bool { .unwrap_or(default) } -/// 从环境变量获取整数值 +/// Retrieves a `usize` value from an environment variable. pub fn get_usize_from_env(key: &str, default: usize) -> usize { std::env::var(key).ok().and_then(|v| v.parse().ok()).unwrap_or(default) } -/// 从环境变量获取 u64 值 +/// Retrieves a `u64` value from an environment variable. pub fn get_u64_from_env(key: &str, default: u64) -> u64 { std::env::var(key).ok().and_then(|v| v.parse().ok()).unwrap_or(default) } -/// 从环境变量获取字符串值 +/// Retrieves a string value from an environment variable. pub fn get_string_from_env(key: &str, default: &str) -> String { std::env::var(key).unwrap_or_else(|_| default.to_string()) } -/// 检查环境变量是否已设置 +/// Checks if an environment variable is set. pub fn is_env_set(key: &str) -> bool { std::env::var(key).is_ok() } -/// 获取所有与可信代理相关的环境变量(用于调试) +/// Returns a list of all proxy-related environment variables and their current values. pub fn get_all_proxy_env_vars() -> Vec<(String, String)> { let vars = [ ENV_PROXY_VALIDATION_MODE, diff --git a/crates/trusted-proxies/src/config/loader.rs b/crates/trusted-proxies/src/config/loader.rs index c86da95d..3ee7eeb8 100644 --- a/crates/trusted-proxies/src/config/loader.rs +++ b/crates/trusted-proxies/src/config/loader.rs @@ -12,58 +12,57 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Configuration loader for environment variables and files +//! Configuration loader for environment variables and files. use std::net::{IpAddr, SocketAddr}; -use std::str::FromStr; use crate::config::env::*; use crate::config::{AppConfig, CacheConfig, CloudConfig, MonitoringConfig, TrustedProxy, TrustedProxyConfig, ValidationMode}; use crate::error::ConfigError; +use rustfs_utils::*; -/// 配置加载器 +/// Loader for application configuration. #[derive(Debug, Clone)] pub struct ConfigLoader; impl ConfigLoader { - /// 从环境变量加载完整应用配置 + /// Loads the complete application configuration from environment variables. pub fn from_env() -> Result { - // 加载可信代理配置 + // Load proxy-specific configuration. let proxy_config = Self::load_proxy_config()?; - // 加载缓存配置 + // Load cache configuration. let cache_config = Self::load_cache_config(); - // 加载监控配置 + // Load monitoring and observability configuration. let monitoring_config = Self::load_monitoring_config(); - // 加载云服务配置 + // Load cloud provider integration configuration. let cloud_config = Self::load_cloud_config(); - // 服务器地址 + // Load server binding address. let server_addr = Self::load_server_addr(); Ok(AppConfig::new(proxy_config, cache_config, monitoring_config, cloud_config, server_addr)) } - /// 加载可信代理配置 + /// Loads trusted proxy configuration from environment variables. fn load_proxy_config() -> Result { - // 解析可信代理列表 let mut proxies = Vec::new(); - // 基础可信代理 + // Parse base trusted proxies from environment. let base_networks = parse_ip_list_from_env(ENV_TRUSTED_PROXIES, DEFAULT_TRUSTED_PROXIES)?; for network in base_networks { proxies.push(TrustedProxy::Cidr(network)); } - // 额外可信代理 + // Parse extra trusted proxies from environment. let extra_networks = parse_ip_list_from_env(ENV_EXTRA_TRUSTED_PROXIES, DEFAULT_EXTRA_TRUSTED_PROXIES)?; for network in extra_networks { proxies.push(TrustedProxy::Cidr(network)); } - // 单个 IP(从环境变量解析) + // Parse individual trusted proxy IPs. let ip_strings = parse_string_list_from_env("TRUSTED_PROXY_IPS", ""); for ip_str in ip_strings { if let Ok(ip) = ip_str.parse::() { @@ -71,16 +70,16 @@ impl ConfigLoader { } } - // 验证模式 - let validation_mode_str = get_string_from_env(ENV_PROXY_VALIDATION_MODE, DEFAULT_PROXY_VALIDATION_MODE); + // Determine validation mode. + let validation_mode_str = get_env_str(ENV_PROXY_VALIDATION_MODE, DEFAULT_PROXY_VALIDATION_MODE); let validation_mode = ValidationMode::from_str(&validation_mode_str)?; - // 其他配置 - let enable_rfc7239 = get_bool_from_env(ENV_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_ENABLE_RFC7239); - let max_hops = get_usize_from_env(ENV_PROXY_MAX_HOPS, DEFAULT_PROXY_MAX_HOPS); - let enable_chain_check = get_bool_from_env(ENV_PROXY_CHAIN_CONTINUITY_CHECK, DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK); + // Load other proxy settings. + let enable_rfc7239 = get_env_bool(ENV_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_ENABLE_RFC7239); + let max_hops = get_env_usize(ENV_PROXY_MAX_HOPS, DEFAULT_PROXY_MAX_HOPS); + let enable_chain_check = get_env_bool(ENV_PROXY_CHAIN_CONTINUITY_CHECK, DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK); - // 私有网络 + // Load private network ranges. let private_networks = parse_ip_list_from_env(ENV_PRIVATE_NETWORKS, DEFAULT_PRIVATE_NETWORKS)?; Ok(TrustedProxyConfig::new( @@ -93,29 +92,29 @@ impl ConfigLoader { )) } - /// 加载缓存配置 + /// Loads cache configuration from environment variables. fn load_cache_config() -> CacheConfig { CacheConfig { - capacity: get_usize_from_env(ENV_CACHE_CAPACITY, DEFAULT_CACHE_CAPACITY), - ttl_seconds: get_u64_from_env(ENV_CACHE_TTL_SECONDS, DEFAULT_CACHE_TTL_SECONDS), - cleanup_interval_seconds: get_u64_from_env(ENV_CACHE_CLEANUP_INTERVAL, DEFAULT_CACHE_CLEANUP_INTERVAL), + capacity: get_env_usize(ENV_CACHE_CAPACITY, DEFAULT_CACHE_CAPACITY), + ttl_seconds: get_env_u64(ENV_CACHE_TTL_SECONDS, DEFAULT_CACHE_TTL_SECONDS), + cleanup_interval_seconds: get_env_u64(ENV_CACHE_CLEANUP_INTERVAL, DEFAULT_CACHE_CLEANUP_INTERVAL), } } - /// 加载监控配置 + /// Loads monitoring configuration from environment variables. fn load_monitoring_config() -> MonitoringConfig { MonitoringConfig { - metrics_enabled: get_bool_from_env(ENV_METRICS_ENABLED, DEFAULT_METRICS_ENABLED), - log_level: get_string_from_env(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL), - structured_logging: get_bool_from_env(ENV_STRUCTURED_LOGGING, DEFAULT_STRUCTURED_LOGGING), - tracing_enabled: get_bool_from_env(ENV_TRACING_ENABLED, DEFAULT_TRACING_ENABLED), - log_failed_validations: get_bool_from_env(ENV_PROXY_LOG_FAILED_VALIDATIONS, DEFAULT_PROXY_LOG_FAILED_VALIDATIONS), + metrics_enabled: get_env_bool(ENV_METRICS_ENABLED, DEFAULT_METRICS_ENABLED), + log_level: get_env_str(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL), + structured_logging: get_env_bool(ENV_STRUCTURED_LOGGING, DEFAULT_STRUCTURED_LOGGING), + tracing_enabled: get_env_bool(ENV_TRACING_ENABLED, DEFAULT_TRACING_ENABLED), + log_failed_validations: get_env_bool(ENV_PROXY_LOG_FAILED_VALIDATIONS, DEFAULT_PROXY_LOG_FAILED_VALIDATIONS), } } - /// 加载云服务配置 + /// Loads cloud configuration from environment variables. fn load_cloud_config() -> CloudConfig { - let forced_provider_str = get_string_from_env(ENV_CLOUD_PROVIDER_FORCE, DEFAULT_CLOUD_PROVIDER_FORCE); + let forced_provider_str = get_env_str(ENV_CLOUD_PROVIDER_FORCE, DEFAULT_CLOUD_PROVIDER_FORCE); let forced_provider = if forced_provider_str.is_empty() { None } else { @@ -123,24 +122,24 @@ impl ConfigLoader { }; CloudConfig { - metadata_enabled: get_bool_from_env(ENV_CLOUD_METADATA_ENABLED, DEFAULT_CLOUD_METADATA_ENABLED), - metadata_timeout_seconds: get_u64_from_env(ENV_CLOUD_METADATA_TIMEOUT, DEFAULT_CLOUD_METADATA_TIMEOUT), - cloudflare_ips_enabled: get_bool_from_env(ENV_CLOUDFLARE_IPS_ENABLED, DEFAULT_CLOUDFLARE_IPS_ENABLED), + metadata_enabled: get_env_bool(ENV_CLOUD_METADATA_ENABLED, DEFAULT_CLOUD_METADATA_ENABLED), + metadata_timeout_seconds: get_env_u64(ENV_CLOUD_METADATA_TIMEOUT, DEFAULT_CLOUD_METADATA_TIMEOUT), + cloudflare_ips_enabled: get_env_bool(ENV_CLOUDFLARE_IPS_ENABLED, DEFAULT_CLOUDFLARE_IPS_ENABLED), forced_provider, } } - /// 加载服务器地址 + /// Loads the server binding address from environment variables. fn load_server_addr() -> SocketAddr { - let host = get_string_from_env("SERVER_HOST", "0.0.0.0"); - let port = get_usize_from_env("SERVER_PORT", 3000) as u16; + let host = get_env_str("SERVER_HOST", "0.0.0.0"); + let port = get_env_usize("SERVER_PORT", 3000) as u16; format!("{}:{}", host, port) .parse() .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 3000))) } - /// 从环境变量加载配置,如果失败则使用默认值 + /// Loads configuration from environment, falling back to defaults on failure. pub fn from_env_or_default() -> AppConfig { match Self::from_env() { Ok(config) => { @@ -154,9 +153,8 @@ impl ConfigLoader { } } - /// 创建默认配置 + /// Returns a default configuration. pub fn default_config() -> AppConfig { - // 默认可信代理配置 let proxy_config = TrustedProxyConfig::new( vec![ TrustedProxy::Single("127.0.0.1".parse().unwrap()), @@ -173,7 +171,6 @@ impl ConfigLoader { ], ); - // 默认应用配置 AppConfig::new( proxy_config, CacheConfig::default(), @@ -183,7 +180,7 @@ impl ConfigLoader { ) } - /// 打印配置摘要 + /// Prints a summary of the configuration to the log. pub fn print_summary(config: &AppConfig) { tracing::info!("=== Application Configuration ==="); tracing::info!("Server: {}", config.server_addr); diff --git a/crates/trusted-proxies/src/config/types.rs b/crates/trusted-proxies/src/config/types.rs index 72802380..c433ccf5 100644 --- a/crates/trusted-proxies/src/config/types.rs +++ b/crates/trusted-proxies/src/config/types.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Configuration type definitions +//! Configuration type definitions for the trusted proxy system. use ipnetwork::IpNetwork; use serde::{Deserialize, Serialize}; @@ -21,25 +21,26 @@ use std::time::Duration; use crate::error::ConfigError; -/// 代理验证模式 +/// Proxy validation mode defining how the proxy chain is verified. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ValidationMode { - /// 宽松模式:只要最后一个代理可信,就接受整个链 + /// Lenient mode: Accepts the entire chain as long as the last proxy is trusted. Lenient, - /// 严格模式:要求链中所有代理都可信 + /// Strict mode: Requires all proxies in the chain to be trusted. Strict, - /// 跳数验证模式:从右向左找到第一个不可信代理 + /// Hop-by-hop mode: Finds the first untrusted proxy from right to left. + /// This is the recommended mode for most production environments. HopByHop, } impl ValidationMode { - /// 从字符串解析验证模式 + /// Parses the validation mode from a string. pub fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "lenient" => Ok(Self::Lenient), "strict" => Ok(Self::Strict), - "hop_by_hop" => Ok(Self::HopByHop), + "hop_by_hop" | "hopbyhop" => Ok(Self::HopByHop), _ => Err(ConfigError::InvalidConfig(format!( "Invalid validation mode: '{}'. Must be one of: lenient, strict, hop_by_hop", s @@ -47,7 +48,7 @@ impl ValidationMode { } } - /// 转换为字符串 + /// Returns the string representation of the validation mode. pub fn as_str(&self) -> &'static str { match self { Self::Lenient => "lenient", @@ -63,17 +64,17 @@ impl Default for ValidationMode { } } -/// 可信代理类型 -#[derive(Debug, Clone)] +/// Represents a trusted proxy entry, which can be a single IP or a CIDR range. +#[derive(Debug, Clone, PartialEq, Eq)] pub enum TrustedProxy { - /// 单个 IP 地址 + /// A single IP address. Single(IpAddr), - /// IP 地址段 (CIDR 表示法) + /// An IP network range (CIDR notation). Cidr(IpNetwork), } impl TrustedProxy { - /// 检查 IP 是否匹配此代理配置 + /// Checks if the given IP address matches this proxy configuration. pub fn contains(&self, ip: &IpAddr) -> bool { match self { Self::Single(proxy_ip) => ip == proxy_ip, @@ -81,7 +82,7 @@ impl TrustedProxy { } } - /// 转换为字符串表示 + /// Returns the string representation of the proxy entry. pub fn to_string(&self) -> String { match self { Self::Single(ip) => ip.to_string(), @@ -90,25 +91,25 @@ impl TrustedProxy { } } -/// 可信代理配置 +/// Configuration for trusted proxies and validation logic. #[derive(Debug, Clone)] pub struct TrustedProxyConfig { - /// 代理列表 + /// List of trusted proxy entries. pub proxies: Vec, - /// 验证模式 + /// The validation mode to use for verifying proxy chains. pub validation_mode: ValidationMode, - /// 是否启用 RFC 7239 Forwarded 头部 + /// Whether to enable RFC 7239 "Forwarded" header support. pub enable_rfc7239: bool, - /// 最大代理跳数 + /// Maximum allowed proxy hops in the chain. pub max_hops: usize, - /// 是否启用链连续性检查 + /// Whether to enable continuity checks for the proxy chain. pub enable_chain_continuity_check: bool, - /// 私有网络范围 + /// Private network ranges that should be treated with caution. pub private_networks: Vec, } impl TrustedProxyConfig { - /// 创建新配置 + /// Creates a new trusted proxy configuration. pub fn new( proxies: Vec, validation_mode: ValidationMode, @@ -127,23 +128,23 @@ impl TrustedProxyConfig { } } - /// 检查 SocketAddr 是否来自可信代理 + /// Checks if a SocketAddr originates from a trusted proxy. pub fn is_trusted(&self, addr: &SocketAddr) -> bool { let ip = addr.ip(); self.proxies.iter().any(|proxy| proxy.contains(&ip)) } - /// 检查 IP 是否在私有网络范围内 + /// Checks if an IP address belongs to a private network range. pub fn is_private_network(&self, ip: &IpAddr) -> bool { self.private_networks.iter().any(|network| network.contains(*ip)) } - /// 获取所有网络范围的字符串表示(用于调试) + /// Returns a list of all network strings for debugging purposes. pub fn get_network_strings(&self) -> Vec { self.proxies.iter().map(|p| p.to_string()).collect() } - /// 获取配置摘要 + /// Returns a summary of the configuration. pub fn summary(&self) -> String { format!( "TrustedProxyConfig {{ proxies: {}, mode: {}, max_hops: {} }}", @@ -154,14 +155,14 @@ impl TrustedProxyConfig { } } -/// 缓存配置 +/// Configuration for the internal caching mechanism. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CacheConfig { - /// 缓存容量 + /// Maximum number of entries in the cache. pub capacity: usize, - /// 缓存 TTL(秒) + /// Time-to-live for cache entries in seconds. pub ttl_seconds: u64, - /// 缓存清理间隔(秒) + /// Interval for cache cleanup in seconds. pub cleanup_interval_seconds: u64, } @@ -176,29 +177,29 @@ impl Default for CacheConfig { } impl CacheConfig { - /// 获取缓存 TTL 时长 + /// Returns the TTL as a Duration. pub fn ttl_duration(&self) -> Duration { Duration::from_secs(self.ttl_seconds) } - /// 获取缓存清理间隔时长 + /// Returns the cleanup interval as a Duration. pub fn cleanup_interval(&self) -> Duration { Duration::from_secs(self.cleanup_interval_seconds) } } -/// 监控配置 +/// Configuration for monitoring and observability. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MonitoringConfig { - /// 是否启用监控指标 + /// Whether to enable Prometheus metrics. pub metrics_enabled: bool, - /// 日志级别 + /// The logging level (e.g., "info", "debug"). pub log_level: String, - /// 是否启用结构化日志 + /// Whether to use structured JSON logging. pub structured_logging: bool, - /// 是否启用请求追踪 + /// Whether to enable distributed tracing. pub tracing_enabled: bool, - /// 是否记录验证失败的请求 + /// Whether to log detailed information about failed validations. pub log_failed_validations: bool, } @@ -214,16 +215,16 @@ impl Default for MonitoringConfig { } } -/// 云服务集成配置 +/// Configuration for cloud provider integration. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CloudConfig { - /// 是否启用云元数据获取 + /// Whether to enable automatic cloud metadata discovery. pub metadata_enabled: bool, - /// 云元数据获取超时(秒) + /// Timeout for cloud metadata requests in seconds. pub metadata_timeout_seconds: u64, - /// 是否启用 Cloudflare IP 范围 + /// Whether to automatically include Cloudflare IP ranges. pub cloudflare_ips_enabled: bool, - /// 强制指定的云服务商 + /// Optionally force a specific cloud provider. pub forced_provider: Option, } @@ -239,29 +240,29 @@ impl Default for CloudConfig { } impl CloudConfig { - /// 获取元数据获取超时时长 + /// Returns the metadata timeout as a Duration. pub fn metadata_timeout(&self) -> Duration { Duration::from_secs(self.metadata_timeout_seconds) } } -/// 完整的应用配置 +/// Complete application configuration. #[derive(Debug, Clone)] pub struct AppConfig { - /// 代理配置 + /// Trusted proxy settings. pub proxy: TrustedProxyConfig, - /// 缓存配置 + /// Cache settings. pub cache: CacheConfig, - /// 监控配置 + /// Monitoring and observability settings. pub monitoring: MonitoringConfig, - /// 云服务配置 + /// Cloud integration settings. pub cloud: CloudConfig, - /// 服务器绑定地址 + /// The address the server should bind to. pub server_addr: SocketAddr, } impl AppConfig { - /// 创建应用配置 + /// Creates a new application configuration. pub fn new( proxy: TrustedProxyConfig, cache: CacheConfig, @@ -278,7 +279,7 @@ impl AppConfig { } } - /// 获取配置摘要 + /// Returns a summary of the application configuration. pub fn summary(&self) -> String { format!( "AppConfig {{\n\ diff --git a/crates/trusted-proxies/src/error/config.rs b/crates/trusted-proxies/src/error/config.rs index 6c91aa02..f265dfbd 100644 --- a/crates/trusted-proxies/src/error/config.rs +++ b/crates/trusted-proxies/src/error/config.rs @@ -12,42 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Configuration error types +//! Configuration error types for the trusted proxy system. use std::net::AddrParseError; -/// 配置错误类型 +/// Errors related to application configuration. #[derive(Debug, thiserror::Error)] pub enum ConfigError { - /// 环境变量缺失 + /// Required environment variable is missing. #[error("Missing environment variable: {0}")] MissingEnvVar(String), - /// 环境变量解析失败 + /// Environment variable exists but could not be parsed. #[error("Failed to parse environment variable {0}: {1}")] EnvParseError(String, String), - /// 无效的配置值 + /// A configuration value is logically invalid. #[error("Invalid configuration value for {0}: {1}")] InvalidValue(String, String), - /// 无效的 IP 地址或网络 + /// An IP address or CIDR range is malformed. #[error("Invalid IP address or network: {0}")] InvalidIp(String), - /// 配置验证失败 + /// Configuration failed overall validation. #[error("Configuration validation failed: {0}")] ValidationFailed(String), - /// 配置冲突 + /// Two or more configuration settings are in conflict. #[error("Configuration conflict: {0}")] Conflict(String), - /// 配置文件错误 + /// Error reading or parsing a configuration file. #[error("Config file error: {0}")] FileError(String), - /// 无效的配置 + /// General invalid configuration error. #[error("Invalid config: {0}")] InvalidConfig(String), } @@ -65,17 +65,17 @@ impl From for ConfigError { } impl ConfigError { - /// 创建环境变量缺失错误 + /// Creates a `MissingEnvVar` error. pub fn missing_env_var(key: &str) -> Self { Self::MissingEnvVar(key.to_string()) } - /// 创建环境变量解析错误 + /// Creates an `EnvParseError`. pub fn env_parse(key: &str, value: &str) -> Self { Self::EnvParseError(key.to_string(), value.to_string()) } - /// 创建无效配置值错误 + /// Creates an `InvalidValue` error. pub fn invalid_value(field: &str, value: &str) -> Self { Self::InvalidValue(field.to_string(), value.to_string()) } diff --git a/crates/trusted-proxies/src/error/mod.rs b/crates/trusted-proxies/src/error/mod.rs index 4b3c7502..de72bad7 100644 --- a/crates/trusted-proxies/src/error/mod.rs +++ b/crates/trusted-proxies/src/error/mod.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Error types for the trusted proxy system +//! Error types for the trusted proxy system. mod config; mod proxy; @@ -20,55 +20,55 @@ mod proxy; pub use config::*; pub use proxy::*; -/// 统一错误类型 +/// Unified error type for the application. #[derive(Debug, thiserror::Error)] pub enum AppError { - /// 配置错误 + /// Errors related to configuration. #[error("Configuration error: {0}")] Config(#[from] ConfigError), - /// 代理验证错误 + /// Errors related to proxy validation. #[error("Proxy validation error: {0}")] Proxy(#[from] ProxyError), - /// 云服务错误 + /// Errors related to cloud service integration. #[error("Cloud service error: {0}")] Cloud(String), - /// 内部错误 + /// General internal errors. #[error("Internal error: {0}")] Internal(String), - /// IO 错误 + /// Standard I/O errors. #[error("IO error: {0}")] Io(#[from] std::io::Error), - /// HTTP 错误 + /// Errors related to HTTP requests or responses. #[error("HTTP error: {0}")] Http(String), } impl AppError { - /// 创建云服务错误 + /// Creates a new `Cloud` error. pub fn cloud(msg: impl Into) -> Self { Self::Cloud(msg.into()) } - /// 创建内部错误 + /// Creates a new `Internal` error. pub fn internal(msg: impl Into) -> Self { Self::Internal(msg.into()) } - /// 创建 HTTP 错误 + /// Creates a new `Http` error. pub fn http(msg: impl Into) -> Self { Self::Http(msg.into()) } - /// 判断错误是否可恢复 + /// Returns true if the error is considered recoverable. pub fn is_recoverable(&self) -> bool { match self { Self::Config(_) => true, - Self::Proxy(_) => true, + Self::Proxy(e) => e.is_recoverable(), Self::Cloud(_) => true, Self::Internal(_) => false, Self::Io(_) => true, @@ -77,7 +77,7 @@ impl AppError { } } -/// HTTP 响应错误类型 +/// Type alias for API error responses (Status Code, Error Message). pub type ApiError = (axum::http::StatusCode, String); impl From for ApiError { diff --git a/crates/trusted-proxies/src/error/proxy.rs b/crates/trusted-proxies/src/error/proxy.rs index 9b01ce2e..f90371c9 100644 --- a/crates/trusted-proxies/src/error/proxy.rs +++ b/crates/trusted-proxies/src/error/proxy.rs @@ -12,50 +12,50 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Proxy validation error types +//! Proxy validation error types for the trusted proxy system. use std::net::AddrParseError; -/// 代理验证错误类型 +/// Errors that can occur during proxy chain validation. #[derive(Debug, thiserror::Error)] pub enum ProxyError { - /// 无效的 X-Forwarded-For 头部 + /// The X-Forwarded-For header is malformed or contains invalid data. #[error("Invalid X-Forwarded-For header: {0}")] InvalidXForwardedFor(String), - /// 无效的 Forwarded 头部(RFC 7239) + /// The RFC 7239 Forwarded header is malformed. #[error("Invalid Forwarded header (RFC 7239): {0}")] InvalidForwardedHeader(String), - /// 代理链验证失败 + /// General failure during proxy chain validation. #[error("Proxy chain validation failed: {0}")] ChainValidationFailed(String), - /// 代理链过长 + /// The number of proxy hops exceeds the configured limit. #[error("Proxy chain too long: {0} hops (max: {1})")] ChainTooLong(usize, usize), - /// 来自不可信代理 + /// The request originated from a proxy that is not in the trusted list. #[error("Request from untrusted proxy: {0}")] UntrustedProxy(String), - /// 代理链不连续 + /// The proxy chain is not continuous (e.g., an untrusted IP is between trusted ones). #[error("Proxy chain is not continuous")] ChainNotContinuous, - /// IP 地址解析失败 + /// An IP address in the chain could not be parsed. #[error("Failed to parse IP address: {0}")] IpParseError(String), - /// 头部解析失败 + /// A header value could not be parsed as a string. #[error("Failed to parse header: {0}")] HeaderParseError(String), - /// 验证超时 + /// Validation took too long and timed out. #[error("Validation timeout")] Timeout, - /// 内部验证错误 + /// An unexpected internal error occurred during validation. #[error("Internal validation error: {0}")] Internal(String), } @@ -67,40 +67,41 @@ impl From for ProxyError { } impl ProxyError { - /// 创建无效 X-Forwarded-For 头部错误 + /// Creates an `InvalidXForwardedFor` error. pub fn invalid_xff(msg: impl Into) -> Self { Self::InvalidXForwardedFor(msg.into()) } - /// 创建无效 Forwarded 头部错误 + /// Creates an `InvalidForwardedHeader` error. pub fn invalid_forwarded(msg: impl Into) -> Self { Self::InvalidForwardedHeader(msg.into()) } - /// 创建代理链验证失败错误 + /// Creates a `ChainValidationFailed` error. pub fn chain_failed(msg: impl Into) -> Self { Self::ChainValidationFailed(msg.into()) } - /// 创建来自不可信代理错误 + /// Creates an `UntrustedProxy` error. pub fn untrusted(proxy: impl Into) -> Self { Self::UntrustedProxy(proxy.into()) } - /// 创建内部验证错误 + /// Creates an `Internal` validation error. pub fn internal(msg: impl Into) -> Self { Self::Internal(msg.into()) } - /// 判断错误是否可恢复(是否应该继续处理请求) + /// Determines if the error is recoverable, meaning the request can still be processed + /// (perhaps by falling back to the direct peer IP). pub fn is_recoverable(&self) -> bool { match self { - // 这些错误通常意味着我们应该拒绝请求或使用备用 IP + // These errors typically mean we should use the direct peer IP as a fallback. Self::UntrustedProxy(_) => true, Self::ChainTooLong(_, _) => true, Self::ChainNotContinuous => true, - // 这些错误可能意味着配置问题或恶意请求 + // These errors suggest malformed requests or severe configuration issues. Self::InvalidXForwardedFor(_) => false, Self::InvalidForwardedHeader(_) => false, Self::ChainValidationFailed(_) => false, diff --git a/crates/trusted-proxies/src/lib.rs b/crates/trusted-proxies/src/lib.rs index da8e80df..824aac09 100644 --- a/crates/trusted-proxies/src/lib.rs +++ b/crates/trusted-proxies/src/lib.rs @@ -12,16 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod api; -mod cloud; -mod config; -mod error; -mod logging; -mod middleware; -mod proxy; -mod state; -mod utils; +pub mod api; +pub mod cloud; +pub mod config; +pub mod error; +pub mod logging; +pub mod middleware; +pub mod proxy; +pub mod state; +pub mod utils; +// Re-export core types for convenience pub use cloud::*; pub use config::*; +pub use middleware::{ClientInfo, TrustedProxyLayer, TrustedProxyMiddleware}; pub use proxy::*; +pub use state::AppState; diff --git a/crates/trusted-proxies/src/logging/middleware.rs b/crates/trusted-proxies/src/logging/middleware.rs index 805b75f1..eb4eaf3f 100644 --- a/crates/trusted-proxies/src/logging/middleware.rs +++ b/crates/trusted-proxies/src/logging/middleware.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Logging middleware for Axum +//! Logging middleware for the Axum web framework. use std::task::{Context, Poll}; use std::time::Instant; @@ -21,14 +21,14 @@ use uuid::Uuid; use crate::logging::Logger; -/// 请求日志中间件层 +/// Tower Layer for request logging middleware. #[derive(Clone)] pub struct RequestLoggingLayer { logger: Logger, } impl RequestLoggingLayer { - /// 创建新的日志中间件层 + /// Creates a new `RequestLoggingLayer`. pub fn new(logger: Logger) -> Self { Self { logger } } @@ -45,7 +45,7 @@ impl tower::Layer for RequestLoggingLayer { } } -/// 请求日志中间件服务 +/// Tower Service for request logging middleware. #[derive(Clone)] pub struct RequestLoggingMiddleware { inner: S, @@ -70,24 +70,22 @@ where let mut inner = self.inner.clone(); Box::pin(async move { - // 生成请求 ID + // Generate a unique request ID for correlation. let request_id = Uuid::new_v4().to_string(); - // 记录请求开始时间和日志 let start_time = Instant::now(); logger.log_request(&req, &request_id); - // 将请求 ID 添加到请求扩展中 + // Inject the request ID into the request extensions. let mut req = req; req.extensions_mut().insert(RequestId(request_id.clone())); - // 处理请求 + // Process the request. let result = inner.call(req).await; - // 计算处理时间 let duration = start_time.elapsed(); - // 记录响应 + // Log the response or error. match &result { Ok(response) => { logger.log_response(response, &request_id, duration); @@ -102,12 +100,12 @@ where } } -/// 请求 ID 包装器 +/// Wrapper for a unique request ID. #[derive(Debug, Clone)] pub struct RequestId(String); impl RequestId { - /// 获取请求 ID + /// Returns the request ID as a string slice. pub fn as_str(&self) -> &str { &self.0 } @@ -119,7 +117,7 @@ impl std::fmt::Display for RequestId { } } -/// 代理特定的日志中间件 +/// Middleware specifically for logging proxy-related information. #[derive(Clone)] pub struct ProxyLoggingMiddleware { inner: S, @@ -127,7 +125,7 @@ pub struct ProxyLoggingMiddleware { } impl ProxyLoggingMiddleware { - /// 创建新的代理日志中间件 + /// Creates a new `ProxyLoggingMiddleware`. pub fn new(inner: S, logger: Logger) -> Self { Self { inner, logger } } @@ -147,7 +145,7 @@ where } fn call(&mut self, mut req: axum::extract::Request) -> Self::Future { - // 记录代理相关信息 + // Log proxy-specific details if available. let peer_addr = req.extensions().get::().copied(); let client_info = req.extensions().get::(); @@ -155,7 +153,7 @@ where self.logger .log_info(&format!("Proxy request from {}: {}", addr, info.to_log_string()), None); - // 如果有警告,记录它们 + // Log any warnings generated during proxy validation. if !info.warnings.is_empty() { for warning in &info.warnings { self.logger.log_warning(warning, Some("proxy_validation")); @@ -167,14 +165,14 @@ where } } -/// 代理日志中间件层 +/// Tower Layer for proxy logging middleware. #[derive(Clone)] pub struct ProxyLoggingLayer { logger: Logger, } impl ProxyLoggingLayer { - /// 创建新的代理日志中间件层 + /// Creates a new `ProxyLoggingLayer`. pub fn new(logger: Logger) -> Self { Self { logger } } diff --git a/crates/trusted-proxies/src/logging/mod.rs b/crates/trusted-proxies/src/logging/mod.rs index 92dd89bb..714c3988 100644 --- a/crates/trusted-proxies/src/logging/mod.rs +++ b/crates/trusted-proxies/src/logging/mod.rs @@ -12,26 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Logging module for structured logging and middleware +//! Logging module for structured logging and observability. mod middleware; pub use middleware::*; -/// 日志配置 +/// Configuration for the logging system. #[derive(Debug, Clone)] pub struct LoggingConfig { - /// 是否启用结构化日志 + /// Whether to use structured JSON logging. pub structured: bool, - /// 日志级别 + /// The logging level (e.g., "info", "debug"). pub level: String, - /// 是否启用请求 ID + /// Whether to include a unique request ID in logs. pub enable_request_id: bool, - /// 是否记录请求体 + /// Whether to log the contents of request bodies. pub log_request_body: bool, - /// 是否记录响应体 + /// Whether to log the contents of response bodies. pub log_response_body: bool, - /// 敏感字段列表(将被脱敏) + /// List of header names that should be redacted in logs. pub sensitive_fields: Vec, } @@ -48,40 +48,31 @@ impl Default for LoggingConfig { "token".to_string(), "secret".to_string(), "authorization".to_string(), + "cookie".to_string(), + "set-cookie".to_string(), ], } } } -/// 初始化日志系统 +/// Initializes the global tracing subscriber. pub fn init_logging(config: &LoggingConfig) -> Result<(), Box> { - // 创建日志过滤器 let filter = tracing_subscriber::EnvFilter::builder() .with_default_directive(config.level.parse().unwrap_or(tracing::Level::INFO.into())) .from_env_lossy(); - // 根据配置选择日志格式 + let subscriber = tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(true) + .with_thread_ids(true) + .with_thread_names(true) + .with_file(true) + .with_line_number(true); + if config.structured { - // 结构化日志(JSON 格式) - tracing_subscriber::fmt() - .json() - .with_env_filter(filter) - .with_target(true) - .with_thread_ids(true) - .with_thread_names(true) - .with_file(true) - .with_line_number(true) - .init(); + subscriber.json().init(); } else { - // 普通文本日志 - tracing_subscriber::fmt() - .with_env_filter(filter) - .with_target(true) - .with_thread_ids(true) - .with_thread_names(true) - .with_file(true) - .with_line_number(true) - .init(); + subscriber.init(); } tracing::info!("Logging initialized with level: {}", config.level); @@ -89,19 +80,19 @@ pub fn init_logging(config: &LoggingConfig) -> Result<(), Box Self { Self { config } } - /// 记录 HTTP 请求 + /// Logs an incoming HTTP request. pub fn log_request(&self, req: &axum::http::Request, request_id: &str) { let method = req.method(); let uri = req.uri(); @@ -115,13 +106,12 @@ impl Logger { "HTTP request received" ); - // 如果启用了请求体日志记录,记录头部 if self.config.log_request_body { self.log_headers(req.headers(), "request"); } } - /// 记录 HTTP 响应 + /// Logs an outgoing HTTP response. pub fn log_response(&self, res: &axum::http::Response, request_id: &str, duration: std::time::Duration) { let status = res.status(); let version = res.version(); @@ -134,13 +124,12 @@ impl Logger { "HTTP response sent" ); - // 如果启用了响应体日志记录,记录头部 if self.config.log_response_body { self.log_headers(res.headers(), "response"); } } - /// 记录头部信息(脱敏敏感字段) + /// Logs HTTP headers, redacting sensitive information. fn log_headers(&self, headers: &axum::http::HeaderMap, header_type: &str) { let mut header_fields = std::collections::HashMap::new(); @@ -151,7 +140,6 @@ impl Logger { Err(_) => "[BINARY]".to_string(), }; - // 检查是否为敏感字段 let is_sensitive = self .config .sensitive_fields @@ -172,7 +160,7 @@ impl Logger { ); } - /// 记录错误 + /// Logs an error with optional request context. pub fn log_error(&self, error: &impl std::error::Error, request_id: Option<&str>) { if let Some(id) = request_id { tracing::error!( @@ -190,7 +178,7 @@ impl Logger { } } - /// 记录警告 + /// Logs a warning message. pub fn log_warning(&self, message: &str, context: Option<&str>) { if let Some(ctx) = context { tracing::warn!(message = %message, context = %ctx, "Warning"); @@ -199,7 +187,7 @@ impl Logger { } } - /// 记录信息 + /// Logs an informational message. pub fn log_info(&self, message: &str, context: Option<&str>) { if let Some(ctx) = context { tracing::info!(message = %message, context = %ctx, "Info"); @@ -208,7 +196,7 @@ impl Logger { } } - /// 记录调试信息 + /// Logs a debug message. pub fn log_debug(&self, message: &str, context: Option<&str>) { if let Some(ctx) = context { tracing::debug!(message = %message, context = %ctx, "Debug"); diff --git a/crates/trusted-proxies/src/main.rs b/crates/trusted-proxies/src/main.rs index 47b8c574..7ac866d6 100644 --- a/crates/trusted-proxies/src/main.rs +++ b/crates/trusted-proxies/src/main.rs @@ -12,19 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Main application entry point for the trusted proxy system +//! Main application entry point for the RustFS Trusted Proxies service. -use std::net::SocketAddr; use std::sync::Arc; -use axum::{ - extract::State, - response::{IntoResponse, Json}, - routing::get, - Router, -}; +use axum::{routing::get, Router}; use tokio::net::TcpListener; -use tracing::{error, info}; +use tracing::{info, Level}; +use tracing_subscriber::EnvFilter; mod api; mod cloud; @@ -36,98 +31,86 @@ mod state; mod utils; use api::handlers; -use config::{AppConfig, ConfigLoader}; +use config::{AppConfig, ConfigLoader, MonitoringConfig}; use error::AppError; use middleware::TrustedProxyLayer; -use proxy::metrics::{default_proxy_metrics, ProxyMetrics}; +use proxy::metrics::default_proxy_metrics; use state::AppState; #[tokio::main] async fn main() -> Result<(), AppError> { - // 加载环境变量 + // Load environment variables from .env file if present. dotenvy::dotenv().ok(); - // 从环境变量加载配置 + // Load application configuration from environment variables. let config = ConfigLoader::from_env_or_default(); - // 初始化日志 + // Initialize the logging system. init_logging(&config.monitoring)?; - // 打印配置摘要 + // Print a summary of the loaded configuration. ConfigLoader::print_summary(&config); - // 初始化指标收集器 + // Initialize metrics collector if enabled. let metrics = if config.monitoring.metrics_enabled { - let metrics = default_proxy_metrics(true); - metrics.print_summary(); - Some(metrics) + let m = default_proxy_metrics(true); + m.print_summary(); + Some(m) } else { None }; - // 创建应用状态 + // Create shared application state. let state = AppState { - config: Arc::new(config), + config: Arc::new(config.clone()), metrics: metrics.clone(), }; - // 创建可信代理中间件层 - let proxy_layer = TrustedProxyLayer::enabled(state.clone().config.proxy.clone(), metrics); + // Initialize the trusted proxy middleware layer. + let proxy_layer = TrustedProxyLayer::enabled(config.proxy.clone(), metrics); - // 创建路由 + // Build the Axum application router. let app = Router::new() - // 健康检查端点 .route("/health", get(handlers::health_check)) - // 配置查看端点 .route("/config", get(handlers::show_config)) - // 客户端信息端点 .route("/client-info", get(handlers::client_info)) - // 代理测试端点 .route("/proxy-test", get(handlers::proxy_test)) - // 指标端点(如果启用) .route("/metrics", get(handlers::metrics)) - // 添加应用状态 - .with_state(state.clone()) - // 添加可信代理中间件 + .with_state(state) .layer(proxy_layer) - // 添加追踪中间件(如果启用) .layer(tower_http::trace::TraceLayer::new_for_http()) - // 添加 CORS 中间件 .layer(tower_http::cors::CorsLayer::permissive()) - // 添加压缩中间件 .layer(tower_http::compression::CompressionLayer::new()); - // 启动服务器 - let addr = state.config.server_addr; - let listener = TcpListener::bind(addr).await.map_err(|e| AppError::Io(e))?; + // Bind the TCP listener and start the server. + let addr = config.server_addr; + let listener = TcpListener::bind(addr).await?; info!("Server listening on http://{}", addr); info!("Available endpoints:"); - info!(" GET /health - Health check"); - info!(" GET /config - Show configuration"); - info!(" GET /client-info - Show client information"); - info!(" GET /proxy-test - Test proxy headers"); + info!(" GET /health - Service health check"); + info!(" GET /config - Current configuration summary"); + info!(" GET /client-info - Extracted client information"); + info!(" GET /proxy-test - Debugging endpoint for proxy headers"); info!(" GET /metrics - Prometheus metrics (if enabled)"); - axum::serve(listener, app).await.map_err(|e| AppError::Io(e))?; + axum::serve(listener, app).await?; Ok(()) } -/// 初始化日志系统 -fn init_logging(monitoring_config: &config::MonitoringConfig) -> Result<(), AppError> { - // 创建日志过滤器 - let filter = tracing_subscriber::EnvFilter::builder() - .with_default_directive(monitoring_config.log_level.parse().unwrap_or(tracing::Level::INFO.into())) +/// Initializes the tracing subscriber for logging. +fn init_logging(monitoring_config: &MonitoringConfig) -> Result<(), AppError> { + let filter = EnvFilter::builder() + .with_default_directive(monitoring_config.log_level.parse().unwrap_or(Level::INFO.into())) .from_env_lossy(); - // 根据配置选择日志格式 + let subscriber = tracing_subscriber::fmt().with_env_filter(filter); + if monitoring_config.structured_logging { - // 结构化日志(JSON 格式) - tracing_subscriber::fmt().json().with_env_filter(filter).init(); + subscriber.json().init(); } else { - // 普通文本日志 - tracing_subscriber::fmt().with_env_filter(filter).init(); + subscriber.init(); } info!("Logging initialized with level: {}", monitoring_config.log_level); diff --git a/crates/trusted-proxies/src/middleware/layer.rs b/crates/trusted-proxies/src/middleware/layer.rs index 6c74f11f..6bb99117 100644 --- a/crates/trusted-proxies/src/middleware/layer.rs +++ b/crates/trusted-proxies/src/middleware/layer.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Tower layer implementation for trusted proxy middleware +//! Tower layer implementation for the trusted proxy middleware. use std::sync::Arc; use tower::Layer; @@ -22,17 +22,17 @@ use crate::middleware::TrustedProxyMiddleware; use crate::proxy::ProxyMetrics; use crate::proxy::ProxyValidator; -/// 可信代理中间件层 +/// Tower Layer for the trusted proxy middleware. #[derive(Clone)] pub struct TrustedProxyLayer { - /// 代理验证器 + /// The validator used to verify proxy chains. pub(crate) validator: Arc, - /// 是否启用中间件 + /// Whether the middleware is enabled. pub(crate) enabled: bool, } impl TrustedProxyLayer { - /// 创建新的中间件层 + /// Creates a new `TrustedProxyLayer`. pub fn new(config: TrustedProxyConfig, metrics: Option, enabled: bool) -> Self { let validator = ProxyValidator::new(config, metrics); @@ -42,12 +42,12 @@ impl TrustedProxyLayer { } } - /// 创建启用的中间件层 + /// Creates a new `TrustedProxyLayer` that is enabled by default. pub fn enabled(config: TrustedProxyConfig, metrics: Option) -> Self { Self::new(config, metrics, true) } - /// 创建禁用的中间件层 + /// Creates a new `TrustedProxyLayer` that is disabled. pub fn disabled() -> Self { Self::new( TrustedProxyConfig::new(Vec::new(), crate::config::ValidationMode::Lenient, true, 10, true, Vec::new()), @@ -56,7 +56,7 @@ impl TrustedProxyLayer { ) } - /// 检查中间件是否启用 + /// Returns true if the middleware is enabled. pub fn is_enabled(&self) -> bool { self.enabled } diff --git a/crates/trusted-proxies/src/middleware/service.rs b/crates/trusted-proxies/src/middleware/service.rs index dbbe586a..4ca51751 100644 --- a/crates/trusted-proxies/src/middleware/service.rs +++ b/crates/trusted-proxies/src/middleware/service.rs @@ -12,33 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Tower service implementation for trusted proxy middleware +//! Tower service implementation for the trusted proxy middleware. use std::sync::Arc; -use std::task::{ready, Context, Poll}; +use std::task::{Context, Poll}; use axum::extract::Request; use axum::response::Response; use tower::Service; use tracing::{debug, instrument, Span}; -use crate::error::ProxyError; use crate::middleware::layer::TrustedProxyLayer; use crate::proxy::{ClientInfo, ProxyValidator}; -/// 可信代理中间件服务 +/// Tower Service for the trusted proxy middleware. #[derive(Clone)] pub struct TrustedProxyMiddleware { - /// 内部服务 + /// The inner service being wrapped. inner: S, - /// 代理验证器 + /// The validator used to verify proxy chains. validator: Arc, - /// 是否启用中间件 + /// Whether the middleware is enabled. enabled: bool, } impl TrustedProxyMiddleware { - /// 创建新的中间件服务 + /// Creates a new `TrustedProxyMiddleware`. pub fn new(inner: S, validator: Arc, enabled: bool) -> Self { Self { inner, @@ -47,7 +46,7 @@ impl TrustedProxyMiddleware { } } - /// 从层创建中间件服务 + /// Creates a new `TrustedProxyMiddleware` from a `TrustedProxyLayer`. pub fn from_layer(inner: S, layer: &TrustedProxyLayer) -> Self { Self::new(inner, layer.validator.clone(), layer.enabled) } @@ -79,57 +78,51 @@ where fn call(&mut self, mut req: Request) -> Self::Future { let span = Span::current(); - // 如果中间件未启用,直接传递请求 + // If the middleware is disabled, pass the request through immediately. if !self.enabled { debug!("Trusted proxy middleware is disabled"); return self.inner.call(req); } - // 记录请求开始时间 let start_time = std::time::Instant::now(); - // 提取对端地址 + // Extract the direct peer address from the request extensions. let peer_addr = req.extensions().get::().copied(); - // 为 span 添加字段 if let Some(addr) = peer_addr { span.record("peer.addr", addr.to_string()); } - // 验证请求并提取客户端信息 + // Validate the request and extract client information. match self.validator.validate_request(peer_addr, req.headers()) { Ok(client_info) => { - // 记录客户端信息到 span span.record("client.ip", client_info.real_ip.to_string()); span.record("client.trusted", client_info.is_from_trusted_proxy); span.record("client.hops", client_info.proxy_hops as i64); - // 将客户端信息存入请求扩展 + // Insert the verified client info into the request extensions. req.extensions_mut().insert(client_info); - // 记录验证成功 let duration = start_time.elapsed(); debug!("Proxy validation successful in {:?}", duration); } Err(err) => { - // 记录验证失败 span.record("error", true); span.record("error.message", err.to_string()); - // 如果是可恢复的错误,创建默认的客户端信息 + // If the error is recoverable, fallback to a direct connection info. if err.is_recoverable() { let client_info = ClientInfo::direct( peer_addr.unwrap_or_else(|| std::net::SocketAddr::new(std::net::IpAddr::from([0, 0, 0, 0]), 0)), ); req.extensions_mut().insert(client_info); } else { - // 对于不可恢复的错误,记录警告 debug!("Unrecoverable proxy validation error: {}", err); } } } - // 调用内部服务 + // Call the inner service. self.inner.call(req) } } diff --git a/crates/trusted-proxies/src/proxy/cache.rs b/crates/trusted-proxies/src/proxy/cache.rs index 07ae6fc5..3a201073 100644 --- a/crates/trusted-proxies/src/proxy/cache.rs +++ b/crates/trusted-proxies/src/proxy/cache.rs @@ -12,226 +12,76 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Cache implementation for proxy validation +//! High-performance cache implementation for proxy validation results using Moka. -use metrics::{counter, gauge, histogram}; -use parking_lot::RwLock; -use std::collections::HashMap; +use moka::future::Cache; use std::net::IpAddr; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use std::time::Duration; -/// 缓存条目 -#[derive(Debug, Clone)] -struct CacheEntry { - /// 是否可信 - is_trusted: bool, - /// 缓存时间 - cached_at: Instant, - /// 过期时间 - expires_at: Instant, -} - -/// IP 验证缓存 +/// Cache for storing IP validation results. #[derive(Debug, Clone)] pub struct IpValidationCache { - /// 缓存存储 - cache: Arc>>, - /// 最大容量 - capacity: usize, - /// 默认 TTL - default_ttl: Duration, - /// 是否启用 + /// The underlying Moka cache. + cache: Cache, + /// Whether the cache is enabled. enabled: bool, } impl IpValidationCache { - /// 创建新的缓存 - pub fn new(capacity: usize, default_ttl: Duration, enabled: bool) -> Self { - Self { - cache: Arc::new(RwLock::new(HashMap::with_capacity(capacity))), - capacity, - default_ttl, - enabled, - } + /// Creates a new `IpValidationCache` using Moka. + pub fn new(capacity: usize, ttl: Duration, enabled: bool) -> Self { + let cache = Cache::builder() + .max_capacity(capacity as u64) + .time_to_live(ttl) + .build(); + + Self { cache, enabled } } - /// 检查 IP 是否可信(带缓存) - pub fn is_trusted(&self, ip: &IpAddr, validator: impl FnOnce(&IpAddr) -> bool) -> bool { - // 如果缓存未启用,直接验证 + /// Checks if an IP is trusted, using the cache if available. + pub async fn is_trusted(&self, ip: &IpAddr, validator: impl FnOnce(&IpAddr) -> bool) -> bool { if !self.enabled { return validator(ip); } - let now = Instant::now(); - - // 检查缓存 - { - let cache = self.cache.read(); - if let Some(entry) = cache.get(ip) { - if now < entry.expires_at { - // 缓存命中 - counter!("proxy.cache.hits").increment(1); - return entry.is_trusted; - } - } + // Attempt to get the result from cache. + if let Some(is_trusted) = self.cache.get(ip).await { + metrics::counter!("proxy.cache.hits").increment(1); + return is_trusted; } - // 缓存未命中 - counter!("proxy.cache.misses").increment(1); - - // 验证 IP + // Cache miss: perform validation and update cache. + metrics::counter!("proxy.cache.misses").increment(1); let is_trusted = validator(ip); - - // 更新缓存 - self.update_cache(*ip, is_trusted, now); + self.cache.insert(*ip, is_trusted).await; is_trusted } - /// 更新缓存 - fn update_cache(&self, ip: IpAddr, is_trusted: bool, now: Instant) { - let mut cache = self.cache.write(); - - // 检查是否需要清理(如果达到容量限制) - if cache.len() >= self.capacity { - self.cleanup_expired(&mut cache, now); - - // 如果仍然满,删除最旧的条目 - if cache.len() >= self.capacity { - self.evict_oldest(&mut cache); - } - } - - // 添加新条目 - let entry = CacheEntry { - is_trusted, - cached_at: now, - expires_at: now + self.default_ttl, - }; - - cache.insert(ip, entry); - - // 更新指标 - gauge!("proxy.cache.size").set(cache.len() as f64); + /// Clears all entries from the cache. + pub async fn clear(&self) { + self.cache.invalidate_all().await; + metrics::gauge!("proxy.cache.size").set(0.0); } - /// 清理过期条目 - fn cleanup_expired(&self, cache: &mut HashMap, now: Instant) { - let expired_keys: Vec<_> = cache - .iter() - .filter(|(_, entry)| now >= entry.expires_at) - .map(|(ip, _)| *ip) - .collect(); - - for key in expired_keys.clone() { - cache.remove(&key); - } - - if !expired_keys.is_empty() { - counter!("proxy.cache.evictions").increment(expired_keys.len() as u64); - } - } - - /// 淘汰最旧的条目 - fn evict_oldest(&self, cache: &mut HashMap) { - if let Some(oldest_key) = cache.iter().min_by_key(|(_, entry)| entry.cached_at).map(|(ip, _)| *ip) { - cache.remove(&oldest_key); - counter!("proxy.cache.evictions").increment(1); - } - } - - /// 清空缓存 - pub fn clear(&self) { - let mut cache = self.cache.write(); - cache.clear(); - gauge!("proxy.cache.size").set(0.00); - } - - /// 获取缓存统计信息 + /// Returns statistics about the current state of the cache. pub fn stats(&self) -> CacheStats { - let cache = self.cache.read(); - - let mut oldest = Instant::now(); - let mut newest = Instant::now(); - let mut expired_count = 0; - - let now = Instant::now(); - for entry in cache.values() { - if entry.cached_at < oldest { - oldest = entry.cached_at; - } - if entry.cached_at > newest { - newest = entry.cached_at; - } - if now >= entry.expires_at { - expired_count += 1; - } - } + let entry_count = self.cache.entry_count(); CacheStats { - size: cache.len(), - capacity: self.capacity, - expired_count, - oldest_age: now.duration_since(oldest), - newest_age: now.duration_since(newest), + size: entry_count as usize, + // Moka doesn't expose max_capacity directly in a simple way after build, + // but we can track it if needed. + capacity: 0, } } - - /// 定期清理任务 - pub async fn cleanup_task(&self, interval: Duration) { - let mut interval_timer = tokio::time::interval(interval); - - loop { - interval_timer.tick().await; - self.cleanup(); - } - } - - /// 执行清理 - fn cleanup(&self) { - let now = Instant::now(); - let mut cache = self.cache.write(); - self.cleanup_expired(&mut cache, now); - - // 记录清理后的指标 - gauge!("proxy.cache.size").set(cache.len() as f64); - } } -/// 缓存统计信息 +/// Statistics about the IP validation cache. #[derive(Debug, Clone)] pub struct CacheStats { - /// 当前缓存大小 + /// Current number of entries in the cache. pub size: usize, - /// 缓存容量 + /// Maximum capacity of the cache. pub capacity: usize, - /// 过期条目数量 - pub expired_count: usize, - /// 最旧条目的年龄 - pub oldest_age: Duration, - /// 最新条目的年龄 - pub newest_age: Duration, -} - -impl CacheStats { - /// 获取缓存使用率 - pub fn usage_percentage(&self) -> f64 { - if self.capacity == 0 { - 0.0 - } else { - (self.size as f64 / self.capacity as f64) * 100.0 - } - } - - /// 获取命中率(需要外部跟踪命中/未命中) - pub fn hit_rate(&self, hits: u64, misses: u64) -> f64 { - let total = hits + misses; - if total == 0 { - 0.0 - } else { - (hits as f64 / total as f64) * 100.0 - } - } } diff --git a/crates/trusted-proxies/src/proxy/chain.rs b/crates/trusted-proxies/src/proxy/chain.rs index 92a5edee..7250a08f 100644 --- a/crates/trusted-proxies/src/proxy/chain.rs +++ b/crates/trusted-proxies/src/proxy/chain.rs @@ -12,48 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Proxy chain analysis and validation - -use std::collections::HashSet; -use std::net::IpAddr; - -use axum::http::HeaderMap; -use tracing::{debug, trace}; +//! Proxy chain analysis and validation logic. use crate::config::{TrustedProxyConfig, ValidationMode}; use crate::error::ProxyError; use crate::utils::ip::is_valid_ip_address; +use axum::http::HeaderMap; +use std::collections::HashSet; +use std::net::IpAddr; +use tracing::trace; -/// 代理链分析结果 +/// Result of analyzing a proxy chain. #[derive(Debug, Clone)] pub struct ChainAnalysis { - /// 客户端真实 IP + /// The identified real client IP address. pub client_ip: IpAddr, - /// 已验证的代理跳数 + /// The number of validated proxy hops. pub hops: usize, - /// 是否连续 + /// Whether the proxy chain is continuous and trusted. pub is_continuous: bool, - /// 警告信息 + /// List of warnings generated during analysis. pub warnings: Vec, - /// 使用的验证模式 + /// The validation mode used for analysis. pub validation_mode: ValidationMode, - /// 可信代理部分 + /// The portion of the chain that consists of trusted proxies. pub trusted_chain: Vec, } -/// 代理链分析器 +/// Analyzer for verifying the integrity of proxy chains. #[derive(Debug, Clone)] pub struct ProxyChainAnalyzer { - /// 代理配置 + /// Configuration for trusted proxies. config: TrustedProxyConfig, - /// 已验证的可信代理 IP 缓存(用于快速查找) + /// Cache of trusted IP addresses for fast lookup. trusted_ip_cache: HashSet, } impl ProxyChainAnalyzer { - /// 创建新的代理链分析器 + /// Creates a new `ProxyChainAnalyzer`. pub fn new(config: TrustedProxyConfig) -> Self { - // 构建可信 IP 缓存 let mut trusted_ip_cache = HashSet::new(); for proxy in &config.proxies { @@ -62,14 +59,10 @@ impl ProxyChainAnalyzer { trusted_ip_cache.insert(*ip); } crate::config::TrustedProxy::Cidr(network) => { - // 对于小网络,可以缓存所有 IP - // 这里我们只缓存/24及更小的网络的前几个IP作为示例 - // 实际生产环境中可能需要更复杂的缓存策略 - if network.prefix() >= 24 && network.prefix() <= 30 { - // 对于小网络,我们可以缓存网络地址和广播地址之间的几个 IP - // 这里简化处理,只缓存网络地址 - if let Some(first_ip) = network.iter().next() { - trusted_ip_cache.insert(first_ip); + // For small networks, cache all IPs to speed up lookups. + if network.prefix() >= 24 { + for ip in network.iter() { + trusted_ip_cache.insert(ip); } } } @@ -82,7 +75,7 @@ impl ProxyChainAnalyzer { } } - /// 分析代理链 + /// Analyzes a proxy chain to identify the real client IP and verify trust. pub fn analyze_chain( &self, proxy_chain: &[IpAddr], @@ -91,33 +84,33 @@ impl ProxyChainAnalyzer { ) -> Result { trace!("Analyzing proxy chain: {:?} with current proxy: {}", proxy_chain, current_proxy_ip); - // 验证 IP 地址 + // Validate all IP addresses in the chain. self.validate_ip_addresses(proxy_chain)?; - // 构建完整链(包括当前代理) + // Construct the full chain including the direct peer. let mut full_chain = proxy_chain.to_vec(); full_chain.push(current_proxy_ip); - // 根据验证模式分析链 + // Analyze the chain based on the configured validation mode. let (client_ip, trusted_chain, hops) = match self.config.validation_mode { ValidationMode::Lenient => self.analyze_lenient(&full_chain), ValidationMode::Strict => self.analyze_strict(&full_chain)?, ValidationMode::HopByHop => self.analyze_hop_by_hop(&full_chain), }; - // 检查链连续性 + // Check for chain continuity if enabled. let is_continuous = if self.config.enable_chain_continuity_check { self.check_chain_continuity(&full_chain, &trusted_chain) } else { true }; - // 收集警告 + // Collect any warnings. let warnings = self.collect_warnings(&full_chain, &trusted_chain, headers); - // 验证客户端 IP + // Final validation of the identified client IP. if !is_valid_ip_address(&client_ip) { - return Err(ProxyError::internal(format!("Invalid client IP: {}", client_ip))); + return Err(ProxyError::internal(format!("Invalid client IP identified: {}", client_ip))); } Ok(ChainAnalysis { @@ -130,33 +123,29 @@ impl ProxyChainAnalyzer { }) } - /// 宽松模式分析:只要最后一个代理可信,就接受整个链 + /// Lenient mode: Accepts the entire chain if the last proxy is trusted. fn analyze_lenient(&self, chain: &[IpAddr]) -> (IpAddr, Vec, usize) { if chain.is_empty() { return (IpAddr::from([0, 0, 0, 0]), Vec::new(), 0); } - // 检查最后一个代理是否可信 if let Some(last_proxy) = chain.last() { if self.is_ip_trusted(last_proxy) { - // 整个链都可信 let client_ip = chain.first().copied().unwrap_or(*last_proxy); return (client_ip, chain.to_vec(), chain.len()); } } - // 如果最后一个代理不可信,使用链中第一个 IP 作为客户端 let client_ip = chain.first().copied().unwrap_or(IpAddr::from([0, 0, 0, 0])); (client_ip, Vec::new(), 0) } - /// 严格模式分析:要求链中所有代理都可信 + /// Strict mode: Requires every IP in the chain to be trusted. fn analyze_strict(&self, chain: &[IpAddr]) -> Result<(IpAddr, Vec, usize), ProxyError> { if chain.is_empty() { return Ok((IpAddr::from([0, 0, 0, 0]), Vec::new(), 0)); } - // 检查每个代理是否都可信 for (i, ip) in chain.iter().enumerate() { if !self.is_ip_trusted(ip) { return Err(ProxyError::chain_failed(format!("Proxy at position {} ({}) is not trusted", i, ip))); @@ -167,7 +156,7 @@ impl ProxyChainAnalyzer { Ok((client_ip, chain.to_vec(), chain.len())) } - /// 跳数模式分析:从右向左找到第一个不可信代理 + /// Hop-by-hop mode: Traverses the chain from right to left to find the first untrusted IP. fn analyze_hop_by_hop(&self, chain: &[IpAddr]) -> (IpAddr, Vec, usize) { if chain.is_empty() { return (IpAddr::from([0, 0, 0, 0]), Vec::new(), 0); @@ -176,28 +165,24 @@ impl ProxyChainAnalyzer { let mut trusted_chain = Vec::new(); let mut validated_hops = 0; - // 从右向左遍历(从离我们最近的代理开始) + // Traverse from the most recent proxy back towards the client. for ip in chain.iter().rev() { if self.is_ip_trusted(ip) { trusted_chain.insert(0, *ip); validated_hops += 1; } else { - // 找到第一个不可信代理,停止遍历 break; } } if trusted_chain.is_empty() { - // 没有可信代理,使用链的最后一个 IP let client_ip = *chain.last().unwrap(); (client_ip, vec![client_ip], 0) } else { - // 客户端 IP 是可信链的第一个 IP 之前的那个 IP let client_ip_index = chain.len().saturating_sub(trusted_chain.len()); let client_ip = if client_ip_index > 0 { chain[client_ip_index - 1] } else { - // 如果整个链都可信,使用第一个 IP chain[0] }; @@ -205,13 +190,12 @@ impl ProxyChainAnalyzer { } } - /// 检查链连续性 + /// Verifies that the trusted portion of the chain is a continuous suffix of the full chain. fn check_chain_continuity(&self, full_chain: &[IpAddr], trusted_chain: &[IpAddr]) -> bool { if full_chain.len() <= 1 || trusted_chain.is_empty() { return true; } - // 可信链应该是完整链的尾部连续部分 if trusted_chain.len() > full_chain.len() { return false; } @@ -220,14 +204,13 @@ impl ProxyChainAnalyzer { expected_tail == trusted_chain } - /// 验证 IP 地址 + /// Validates that IP addresses are not unspecified, multicast, or otherwise invalid. fn validate_ip_addresses(&self, chain: &[IpAddr]) -> Result<(), ProxyError> { for ip in chain { if !is_valid_ip_address(ip) { return Err(ProxyError::IpParseError(format!("Invalid IP address in chain: {}", ip))); } - // 检查是否为特殊地址 if ip.is_unspecified() { return Err(ProxyError::invalid_xff("IP address cannot be unspecified (0.0.0.0 or ::)")); } @@ -240,42 +223,35 @@ impl ProxyChainAnalyzer { Ok(()) } - /// 检查 IP 是否可信 + /// Checks if an IP address is trusted based on the configuration. fn is_ip_trusted(&self, ip: &IpAddr) -> bool { - // 首先检查缓存 if self.trusted_ip_cache.contains(ip) { return true; } - // 然后检查配置中的代理 self.config.proxies.iter().any(|proxy| proxy.contains(ip)) } - /// 收集警告信息 + /// Collects warnings about potential issues in the proxy chain. fn collect_warnings(&self, full_chain: &[IpAddr], trusted_chain: &[IpAddr], headers: &HeaderMap) -> Vec { let mut warnings = Vec::new(); - // 检查代理链长度 if full_chain.len() > self.config.max_hops { warnings.push(format!( - "Proxy chain length ({}) exceeds recommended maximum ({})", + "Proxy chain length ({}) exceeds configured maximum ({})", full_chain.len(), self.config.max_hops )); } - // 检查是否缺少必要的头部 - if trusted_chain.len() > 0 { - if !headers.contains_key("x-forwarded-for") && !headers.contains_key("forwarded") { - warnings.push("No proxy headers found for trusted proxy request".to_string()); - } + if !trusted_chain.is_empty() && !headers.contains_key("x-forwarded-for") && !headers.contains_key("forwarded") { + warnings.push("No proxy headers found for request from trusted proxy".to_string()); } - // 检查是否有重复的 IP let mut seen_ips = HashSet::new(); for ip in full_chain { if !seen_ips.insert(ip) { - warnings.push(format!("Duplicate IP in proxy chain: {}", ip)); + warnings.push(format!("Duplicate IP address detected in proxy chain: {}", ip)); break; } } diff --git a/crates/trusted-proxies/src/proxy/metrics.rs b/crates/trusted-proxies/src/proxy/metrics.rs index 99b5e633..af801c41 100644 --- a/crates/trusted-proxies/src/proxy/metrics.rs +++ b/crates/trusted-proxies/src/proxy/metrics.rs @@ -12,67 +12,58 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Metrics and monitoring for proxy validation - -use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; -use std::time::Duration; -use tracing::{info, warn}; +//! Metrics and monitoring for proxy validation performance and results. use crate::config::ValidationMode; use crate::error::ProxyError; +use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; +use std::time::Duration; +use tracing::info; -/// 代理验证指标 +/// Collector for proxy validation metrics. #[derive(Debug, Clone)] pub struct ProxyMetrics { - /// 是否启用指标 + /// Whether metrics collection is enabled. enabled: bool, - /// 应用名称(用于指标标签) + /// Application name used as a label for metrics. app_name: String, } impl ProxyMetrics { - /// 创建新的指标收集器 + /// Creates a new `ProxyMetrics` collector. pub fn new(app_name: &str, enabled: bool) -> Self { let metrics = Self { enabled, app_name: app_name.to_string(), }; - // 注册指标描述 + // Register metric descriptions for Prometheus. metrics.register_descriptions(); metrics } - /// 注册指标描述 + /// Registers descriptions for all metrics. fn register_descriptions(&self) { if !self.enabled { return; } describe_counter!("proxy_validation_attempts_total", "Total number of proxy validation attempts"); - describe_counter!("proxy_validation_success_total", "Total number of successful proxy validations"); - describe_counter!("proxy_validation_failure_total", "Total number of failed proxy validations"); - describe_counter!( "proxy_validation_failure_by_type_total", - "Total number of failed proxy validations by error type" + "Total number of failed proxy validations categorized by error type" ); - - describe_gauge!("proxy_chain_length", "Length of proxy chains being validated"); - - describe_histogram!("proxy_validation_duration_seconds", "Duration of proxy validation in seconds"); - - describe_gauge!("proxy_cache_size", "Size of the proxy validation cache"); - - describe_counter!("proxy_cache_hits_total", "Total number of cache hits"); - - describe_counter!("proxy_cache_misses_total", "Total number of cache misses"); + describe_gauge!("proxy_chain_length", "Current length of proxy chains being validated"); + describe_histogram!("proxy_validation_duration_seconds", "Time taken to validate a proxy chain in seconds"); + describe_gauge!("proxy_cache_size", "Current number of entries in the proxy validation cache"); + describe_counter!("proxy_cache_hits_total", "Total number of cache hits for proxy validation"); + describe_counter!("proxy_cache_misses_total", "Total number of cache misses for proxy validation"); } - /// 记录验证尝试 + /// Increments the total number of validation attempts. pub fn increment_validation_attempts(&self) { if !self.enabled { return; @@ -85,7 +76,7 @@ impl ProxyMetrics { ); } - /// 记录验证成功 + /// Records a successful validation. pub fn record_validation_success(&self, from_trusted_proxy: bool, proxy_hops: usize, duration: Duration) { if !self.enabled { return; @@ -111,7 +102,7 @@ impl ProxyMetrics { ); } - /// 记录验证失败 + /// Records a failed validation with the specific error type. pub fn record_validation_failure(&self, error: &ProxyError, duration: Duration) { if !self.enabled { return; @@ -152,7 +143,7 @@ impl ProxyMetrics { ); } - /// 记录验证模式使用情况 + /// Records the validation mode currently in use. pub fn record_validation_mode(&self, mode: ValidationMode) { if !self.enabled { return; @@ -170,7 +161,7 @@ impl ProxyMetrics { ); } - /// 记录缓存指标 + /// Records cache performance metrics. pub fn record_cache_metrics(&self, hits: u64, misses: u64, size: usize) { if !self.enabled { return; @@ -181,7 +172,7 @@ impl ProxyMetrics { gauge!("proxy_cache_size", size as f64, "app" => self.app_name.clone()); } - /// 打印指标摘要 + /// Prints a summary of enabled metrics to the log. pub fn print_summary(&self) { if !self.enabled { info!("Metrics collection is disabled"); @@ -202,10 +193,10 @@ impl ProxyMetrics { } } -/// 默认应用名称 +/// Default application name for metrics. const DEFAULT_APP_NAME: &str = "trusted-proxy"; -/// 创建默认的代理指标收集器 +/// Creates a default `ProxyMetrics` collector. pub fn default_proxy_metrics(enabled: bool) -> ProxyMetrics { ProxyMetrics::new(DEFAULT_APP_NAME, enabled) } diff --git a/crates/trusted-proxies/src/proxy/validator.rs b/crates/trusted-proxies/src/proxy/validator.rs index 0dfc279e..12e4914f 100644 --- a/crates/trusted-proxies/src/proxy/validator.rs +++ b/crates/trusted-proxies/src/proxy/validator.rs @@ -12,42 +12,41 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Proxy validator for validating proxy chains and client information - -use std::net::{IpAddr, SocketAddr}; -use std::time::Instant; +//! Proxy validator for verifying proxy chains and extracting client information. use axum::http::HeaderMap; -use tracing::{debug, trace, warn}; +use std::net::{IpAddr, SocketAddr}; +use std::time::Instant; +use tracing::{debug, warn}; use crate::config::{TrustedProxyConfig, ValidationMode}; use crate::error::ProxyError; use crate::proxy::chain::ProxyChainAnalyzer; use crate::proxy::metrics::ProxyMetrics; -/// 客户端信息验证结果 +/// Information about the client extracted from the request and proxy headers. #[derive(Debug, Clone)] pub struct ClientInfo { - /// 真实客户端 IP 地址(已验证) + /// The verified real IP address of the client. pub real_ip: IpAddr, - /// 原始请求主机名(如果来自可信代理) + /// The original host requested by the client (if provided by a trusted proxy). pub forwarded_host: Option, - /// 原始请求协议(如果来自可信代理) + /// The original protocol (http/https) used by the client (if provided by a trusted proxy). pub forwarded_proto: Option, - /// 请求是否来自可信代理 + /// Whether the request was received from a trusted proxy. pub is_from_trusted_proxy: bool, - /// 直接连接的代理 IP(如果经过代理) + /// The IP address of the proxy that directly connected to this server. pub proxy_ip: Option, - /// 代理链长度 + /// The number of proxy hops identified in the chain. pub proxy_hops: usize, - /// 验证模式 + /// The validation mode used for this request. pub validation_mode: ValidationMode, - /// 验证警告信息 + /// Any warnings generated during the validation process. pub warnings: Vec, } impl ClientInfo { - /// 创建直接连接的客户端信息(无代理) + /// Creates a `ClientInfo` for a direct connection without any proxies. pub fn direct(addr: SocketAddr) -> Self { Self { real_ip: addr.ip(), @@ -61,7 +60,7 @@ impl ClientInfo { } } - /// 从可信代理创建客户端信息 + /// Creates a `ClientInfo` for a request received through a trusted proxy. pub fn from_trusted_proxy( real_ip: IpAddr, forwarded_host: Option, @@ -83,7 +82,7 @@ impl ClientInfo { } } - /// 获取客户端信息的字符串表示(用于日志) + /// Returns a string representation of the client info for logging. pub fn to_log_string(&self) -> String { format!( "client_ip={}, proxy={:?}, hops={}, trusted={}, mode={:?}", @@ -92,19 +91,19 @@ impl ClientInfo { } } -/// 代理验证器 +/// Core validator that processes incoming requests to verify proxy chains. #[derive(Debug, Clone)] pub struct ProxyValidator { - /// 代理配置 + /// Configuration for trusted proxies. config: TrustedProxyConfig, - /// 代理链分析器 + /// Analyzer for verifying the integrity of the proxy chain. chain_analyzer: ProxyChainAnalyzer, - /// 监控指标 + /// Metrics collector for observability. metrics: Option, } impl ProxyValidator { - /// 创建新的代理验证器 + /// Creates a new `ProxyValidator` with the given configuration and metrics. pub fn new(config: TrustedProxyConfig, metrics: Option) -> Self { let chain_analyzer = ProxyChainAnalyzer::new(config.clone()); @@ -115,53 +114,53 @@ impl ProxyValidator { } } - /// 验证请求并提取客户端信息 + /// Validates an incoming request and extracts client information. pub fn validate_request(&self, peer_addr: Option, headers: &HeaderMap) -> Result { let start_time = Instant::now(); - // 记录验证开始 + // Record the start of the validation attempt. self.record_metric_start(); - // 验证请求 + // Perform the internal validation logic. let result = self.validate_request_internal(peer_addr, headers); - // 记录验证结果 + // Record the result and duration. let duration = start_time.elapsed(); self.record_metric_result(&result, duration); result } - /// 内部验证逻辑 + /// Internal logic for request validation. fn validate_request_internal(&self, peer_addr: Option, headers: &HeaderMap) -> Result { - // 如果没有对端地址,使用默认值 + // Fallback to unspecified address if peer address is missing. let peer_addr = peer_addr.unwrap_or_else(|| SocketAddr::new(IpAddr::from([0, 0, 0, 0]), 0)); - // 检查是否来自可信代理 + // Check if the direct peer is a trusted proxy. if self.config.is_trusted(&peer_addr) { - debug!("Request from trusted proxy: {}", peer_addr.ip()); + debug!("Request received from trusted proxy: {}", peer_addr.ip()); - // 来自可信代理,解析转发头部 + // Parse and validate headers from the trusted proxy. self.validate_trusted_proxy_request(&peer_addr, headers) } else { - // 检查是否为私有网络地址 + // Log a warning if the request is from a private network but not trusted. if self.config.is_private_network(&peer_addr.ip()) { warn!( - "Request from private network but not trusted: {}. This might be a configuration issue.", + "Request from private network but not trusted: {}. This might indicate a configuration issue.", peer_addr.ip() ); } - // 来自不可信代理或直接连接 + // Treat as a direct connection if the peer is not trusted. Ok(ClientInfo::direct(peer_addr)) } } - /// 验证来自可信代理的请求 + /// Validates a request that originated from a trusted proxy. fn validate_trusted_proxy_request(&self, proxy_addr: &SocketAddr, headers: &HeaderMap) -> Result { let proxy_ip = proxy_addr.ip(); - // 优先使用 RFC 7239 Forwarded 头部(如果启用) + // Prefer RFC 7239 "Forwarded" header if enabled, otherwise fallback to legacy headers. let client_info = if self.config.enable_rfc7239 { self.try_parse_rfc7239_headers(headers, proxy_ip) .unwrap_or_else(|| self.parse_legacy_headers(headers, proxy_ip)) @@ -169,28 +168,21 @@ impl ProxyValidator { self.parse_legacy_headers(headers, proxy_ip) }; - // 验证代理链 + // Analyze the integrity and continuity of the proxy chain. let chain_analysis = self .chain_analyzer .analyze_chain(&client_info.proxy_chain, proxy_ip, headers)?; - // 检查代理链长度 + // Enforce maximum hop limit. if chain_analysis.hops > self.config.max_hops { return Err(ProxyError::ChainTooLong(chain_analysis.hops, self.config.max_hops)); } - // 检查链连续性(如果启用) + // Enforce chain continuity if enabled. if self.config.enable_chain_continuity_check && !chain_analysis.is_continuous { return Err(ProxyError::ChainNotContinuous); } - // 创建客户端信息 - let warnings = if !chain_analysis.warnings.is_empty() { - chain_analysis.warnings - } else { - Vec::new() - }; - Ok(ClientInfo::from_trusted_proxy( chain_analysis.client_ip, client_info.forwarded_host, @@ -198,11 +190,11 @@ impl ProxyValidator { proxy_ip, chain_analysis.hops, self.config.validation_mode, - warnings, + chain_analysis.warnings, )) } - /// 尝试解析 RFC 7239 Forwarded 头部 + /// Attempts to parse the RFC 7239 "Forwarded" header. fn try_parse_rfc7239_headers(&self, headers: &HeaderMap, proxy_ip: IpAddr) -> Option { headers .get("forwarded") @@ -210,7 +202,7 @@ impl ProxyValidator { .and_then(|s| Self::parse_forwarded_header(s, proxy_ip)) } - /// 解析传统的代理头部 + /// Parses legacy proxy headers (X-Forwarded-For, X-Forwarded-Host, X-Forwarded-Proto). fn parse_legacy_headers(&self, headers: &HeaderMap, proxy_ip: IpAddr) -> ParsedHeaders { let forwarded_host = headers .get("x-forwarded-host") @@ -225,8 +217,8 @@ impl ProxyValidator { let proxy_chain = headers .get("x-forwarded-for") .and_then(|h| h.to_str().ok()) - .map(|s| Self::parse_x_forwarded_for(s)) - .unwrap_or_else(Vec::new); + .map(Self::parse_x_forwarded_for) + .unwrap_or_default(); ParsedHeaders { proxy_chain, @@ -235,16 +227,15 @@ impl ProxyValidator { } } - /// 解析 RFC 7239 Forwarded 头部 + /// Parses the RFC 7239 "Forwarded" header value. fn parse_forwarded_header(header_value: &str, proxy_ip: IpAddr) -> Option { - // 简化实现:只处理第一个值 + // Simplified implementation: processes only the first entry in the header. let first_part = header_value.split(',').next()?.trim(); let mut proxy_chain = Vec::new(); let mut forwarded_host = None; let mut forwarded_proto = None; - // 解析键值对 for part in first_part.split(';') { let part = part.trim(); if let Some((key, value)) = part.split_once('=') { @@ -253,7 +244,7 @@ impl ProxyValidator { match key.as_str() { "for" => { - // 解析客户端 IP(可能包含端口) + // Extract IP address, ignoring port if present. if let Some(ip_part) = value.split(':').next() { if let Ok(ip) = ip_part.parse::() { proxy_chain.push(ip); @@ -271,7 +262,7 @@ impl ProxyValidator { } } - // 如果没有找到客户端 IP,添加代理 IP 作为备选 + // Fallback to the proxy IP if no client IP was found in the header. if proxy_chain.is_empty() { proxy_chain.push(proxy_ip); } @@ -283,28 +274,28 @@ impl ProxyValidator { }) } - /// 解析 X-Forwarded-For 头部 + /// Parses the X-Forwarded-For header into a list of IP addresses. fn parse_x_forwarded_for(header_value: &str) -> Vec { header_value .split(',') .map(|s| s.trim()) .filter(|s| !s.is_empty()) .filter_map(|s| { - // 移除端口部分(如果存在) + // Strip port if present. let ip_part = s.split(':').next().unwrap_or(s); ip_part.parse::().ok() }) .collect() } - /// 记录验证开始指标 + /// Records the start of a validation attempt in metrics. fn record_metric_start(&self) { if let Some(metrics) = &self.metrics { metrics.increment_validation_attempts(); } } - /// 记录验证结果指标 + /// Records the result of a validation attempt in metrics. fn record_metric_result(&self, result: &Result, duration: std::time::Duration) { if let Some(metrics) = &self.metrics { match result { @@ -314,7 +305,6 @@ impl ProxyValidator { Err(err) => { metrics.record_validation_failure(err, duration); - // 记录失败的验证(如果启用) if self.config.log_failed_validations { warn!("Proxy validation failed: {}", err); } @@ -324,13 +314,13 @@ impl ProxyValidator { } } -/// 解析后的头部信息 +/// Internal structure for holding parsed header information. #[derive(Debug, Clone)] struct ParsedHeaders { - /// 代理链(客户端 IP 在第一个位置) + /// The chain of proxy IPs (client IP is typically the first). proxy_chain: Vec, - /// 转发的主机名 + /// The original host requested. forwarded_host: Option, - /// 转发的协议 + /// The original protocol used. forwarded_proto: Option, } diff --git a/crates/trusted-proxies/src/state.rs b/crates/trusted-proxies/src/state.rs index d5a94e18..85a551e1 100644 --- a/crates/trusted-proxies/src/state.rs +++ b/crates/trusted-proxies/src/state.rs @@ -12,14 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Shared application state. + use crate::{AppConfig, ProxyMetrics}; use std::sync::Arc; -/// 应用状态 +/// Global application state shared across handlers and middleware. #[derive(Clone)] pub struct AppState { - /// 应用配置 + /// Immutable application configuration. pub config: Arc, - /// 代理指标收集器 + /// Optional metrics collector for observability. pub metrics: Option, } diff --git a/crates/trusted-proxies/src/utils/ip.rs b/crates/trusted-proxies/src/utils/ip.rs index 2af0a06f..1d298718 100644 --- a/crates/trusted-proxies/src/utils/ip.rs +++ b/crates/trusted-proxies/src/utils/ip.rs @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! IP address utility functions +//! IP address utility functions for validation and classification. use ipnetwork::IpNetwork; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str::FromStr; -/// IP 工具函数集合 +/// Collection of IP-related utility functions. pub struct IpUtils; impl IpUtils { - /// 检查 IP 地址是否有效 + /// Checks if an IP address is valid for general use (not unspecified, multicast, or reserved). pub fn is_valid_ip_address(ip: &IpAddr) -> bool { !ip.is_unspecified() && !ip.is_multicast() && !Self::is_reserved_ip(ip) } - /// 检查 IP 是否为保留地址 + /// Checks if an IP address belongs to a reserved range. pub fn is_reserved_ip(ip: &IpAddr) -> bool { match ip { IpAddr::V4(ipv4) => Self::is_reserved_ipv4(ipv4), @@ -35,11 +35,11 @@ impl IpUtils { } } - /// 检查 IPv4 是否为保留地址 + /// Checks if an IPv4 address belongs to a reserved range. pub fn is_reserved_ipv4(ip: &Ipv4Addr) -> bool { let octets = ip.octets(); - // 检查常见的保留地址范围 + // Check common reserved IPv4 ranges matches!( octets, [0, _, _, _] | // 0.0.0.0/8 @@ -60,11 +60,11 @@ impl IpUtils { ) } - /// 检查 IPv6 是否为保留地址 + /// Checks if an IPv6 address belongs to a reserved range. pub fn is_reserved_ipv6(ip: &Ipv6Addr) -> bool { let segments = ip.segments(); - // 检查常见的保留地址范围 + // Check common reserved IPv6 ranges matches!( segments, [0, 0, 0, 0, 0, 0, 0, 0] | // ::/128 @@ -76,7 +76,7 @@ impl IpUtils { ) } - /// 检查 IP 是否为私有地址 + /// Checks if an IP address is a private address. pub fn is_private_ip(ip: &IpAddr) -> bool { match ip { IpAddr::V4(ipv4) => Self::is_private_ipv4(ipv4), @@ -84,7 +84,7 @@ impl IpUtils { } } - /// 检查 IPv4 是否为私有地址 + /// Checks if an IPv4 address is a private address. pub fn is_private_ipv4(ip: &Ipv4Addr) -> bool { let octets = ip.octets(); @@ -96,7 +96,7 @@ impl IpUtils { ) } - /// 检查 IPv6 是否为私有地址 + /// Checks if an IPv6 address is a private address. pub fn is_private_ipv6(ip: &Ipv6Addr) -> bool { let segments = ip.segments(); @@ -106,7 +106,7 @@ impl IpUtils { ) } - /// 检查 IP 是否为回环地址 + /// Checks if an IP address is a loopback address. pub fn is_loopback_ip(ip: &IpAddr) -> bool { match ip { IpAddr::V4(ipv4) => ipv4.is_loopback(), @@ -114,7 +114,7 @@ impl IpUtils { } } - /// 检查 IP 是否为链路本地地址 + /// Checks if an IP address is a link-local address. pub fn is_link_local_ip(ip: &IpAddr) -> bool { match ip { IpAddr::V4(ipv4) => ipv4.is_link_local(), @@ -122,7 +122,7 @@ impl IpUtils { } } - /// 检查 IP 是否为文档地址(TEST-NET) + /// Checks if an IP address is a documentation address (TEST-NET). pub fn is_documentation_ip(ip: &IpAddr) -> bool { match ip { IpAddr::V4(ipv4) => { @@ -141,12 +141,12 @@ impl IpUtils { } } - /// 从字符串解析 IP 地址,支持 CIDR 表示法 + /// Parses an IP address or CIDR range from a string. pub fn parse_ip_or_cidr(s: &str) -> Result { IpNetwork::from_str(s).map_err(|e| format!("Failed to parse IP/CIDR '{}': {}", s, e)) } - /// 从逗号分隔的字符串解析 IP 列表 + /// Parses a comma-separated list of IP addresses. pub fn parse_ip_list(s: &str) -> Result, String> { let mut ips = Vec::new(); @@ -165,7 +165,7 @@ impl IpUtils { Ok(ips) } - /// 从逗号分隔的字符串解析网络列表 + /// Parses a comma-separated list of IP networks (CIDR). pub fn parse_network_list(s: &str) -> Result, String> { let mut networks = Vec::new(); @@ -184,12 +184,12 @@ impl IpUtils { Ok(networks) } - /// 检查 IP 是否在给定的网络列表中 + /// Checks if an IP address is contained within any of the given networks. pub fn ip_in_networks(ip: &IpAddr, networks: &[IpNetwork]) -> bool { networks.iter().any(|network| network.contains(*ip)) } - /// 获取 IP 地址的类型描述 + /// Returns a string description of the IP address type. pub fn get_ip_type(ip: &IpAddr) -> &'static str { if Self::is_private_ip(ip) { "private" @@ -206,16 +206,16 @@ impl IpUtils { } } - /// 将 IP 地址转换为规范形式 + /// Returns the canonical string representation of an IP address. pub fn canonical_ip(ip: &IpAddr) -> String { match ip { IpAddr::V4(ipv4) => ipv4.to_string(), IpAddr::V6(ipv6) => { - // 压缩 IPv6 地址 + // Compress IPv6 address let mut result = String::new(); let segments = ipv6.segments(); - // 查找最长的连续零段 + // Find the longest sequence of zero segments let mut longest_start = 0; let mut longest_len = 0; let mut current_start = 0; @@ -241,20 +241,21 @@ impl IpUtils { longest_len = current_len; } - // 格式化为字符串 - for mut i in 0..8 { + // Format as string + let mut i = 0; + while i < 8 { if i == longest_start && longest_len > 1 { result.push_str("::"); - i += longest_len - 1; - } else if i == longest_start && longest_len == 1 { - result.push('0'); + i += longest_len; + if i == 8 { + break; + } } else { - if i > 0 && i != longest_start { + if i > 0 && (i != longest_start + longest_len || longest_len <= 1) { result.push(':'); } - if segments[i] != 0 || (i == 7 && result.is_empty()) { - result.push_str(&format!("{:x}", segments[i])); - } + result.push_str(&format!("{:x}", segments[i])); + i += 1; } } @@ -263,3 +264,8 @@ impl IpUtils { } } } + +/// Helper function to check if an IP address is valid. +pub fn is_valid_ip_address(ip: &IpAddr) -> bool { + IpUtils::is_valid_ip_address(ip) +} diff --git a/crates/trusted-proxies/src/utils/mod.rs b/crates/trusted-proxies/src/utils/mod.rs index e40c4668..9b46f613 100644 --- a/crates/trusted-proxies/src/utils/mod.rs +++ b/crates/trusted-proxies/src/utils/mod.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Utility functions and helpers +//! Utility functions and helpers for the trusted proxy system. mod ip; mod validation; @@ -20,32 +20,32 @@ mod validation; pub use ip::*; pub use validation::*; -/// 工具函数集合 +/// Collection of general utility functions. #[derive(Debug, Clone)] pub struct Utils; impl Utils { - /// 生成追踪 ID + /// Generates a unique trace ID. pub fn generate_trace_id() -> String { format!("trace-{}", uuid::Uuid::new_v4()) } - /// 生成 Span ID + /// Generates a unique span ID. pub fn generate_span_id() -> String { format!("span-{}", uuid::Uuid::new_v4()) } - /// 安全的将字符串转换为 usize + /// Safely parses a string into a `usize`, returning a default value on failure. pub fn safe_parse_usize(s: &str, default: usize) -> usize { s.parse().unwrap_or(default) } - /// 安全的将字符串转换为 u64 + /// Safely parses a string into a `u64`, returning a default value on failure. pub fn safe_parse_u64(s: &str, default: u64) -> u64 { s.parse().unwrap_or(default) } - /// 安全的将字符串转换为布尔值 + /// Safely parses a string into a boolean, returning a default value on failure. pub fn safe_parse_bool(s: &str, default: bool) -> bool { match s.to_lowercase().as_str() { "true" | "1" | "yes" | "on" => true, @@ -54,7 +54,7 @@ impl Utils { } } - /// 格式化持续时间 + /// Formats a `Duration` into a human-readable string. pub fn format_duration(duration: std::time::Duration) -> String { if duration.as_secs() > 0 { format!("{:.2}s", duration.as_secs_f64()) @@ -67,22 +67,22 @@ impl Utils { } } - /// 获取当前时间戳 + /// Returns the current UTC timestamp in RFC 3339 format. pub fn current_timestamp() -> String { chrono::Utc::now().to_rfc3339() } - /// 安全的获取环境变量 + /// Safely retrieves an environment variable. pub fn get_env_var(key: &str) -> Option { std::env::var(key).ok() } - /// 获取环境变量,如果不存在则使用默认值 + /// Retrieves an environment variable or returns a default value if not set. pub fn get_env_var_or(key: &str, default: &str) -> String { std::env::var(key).unwrap_or_else(|_| default.to_string()) } - /// 检查环境变量是否存在 + /// Checks if an environment variable is set. pub fn has_env_var(key: &str) -> bool { std::env::var(key).is_ok() } diff --git a/crates/trusted-proxies/src/utils/validation.rs b/crates/trusted-proxies/src/utils/validation.rs index e4bdbf43..48ca6746 100644 --- a/crates/trusted-proxies/src/utils/validation.rs +++ b/crates/trusted-proxies/src/utils/validation.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Validation utility functions +//! Validation utility functions for various data types. use http::HeaderMap; use lazy_static::lazy_static; @@ -20,11 +20,11 @@ use regex::Regex; use std::net::IpAddr; use std::str::FromStr; -/// 验证工具函数集合 +/// Collection of validation utility functions. pub struct ValidationUtils; impl ValidationUtils { - /// 验证电子邮件地址 + /// Validates an email address format. pub fn is_valid_email(email: &str) -> bool { lazy_static! { static ref EMAIL_REGEX: Regex = @@ -34,7 +34,7 @@ impl ValidationUtils { EMAIL_REGEX.is_match(email) } - /// 验证 URL + /// Validates a URL format. pub fn is_valid_url(url: &str) -> bool { lazy_static! { static ref URL_REGEX: Regex = @@ -44,22 +44,19 @@ impl ValidationUtils { URL_REGEX.is_match(url) } - /// 验证 X-Forwarded-For 头部 + /// Validates the format of an X-Forwarded-For header value. pub fn validate_x_forwarded_for(header_value: &str) -> bool { if header_value.is_empty() { return false; } - // 分割 IP 地址 let ips: Vec<&str> = header_value.split(',').map(|s| s.trim()).collect(); - // 检查每个 IP 地址 for ip_str in ips { if ip_str.is_empty() { return false; } - // 移除端口部分(如果存在) let ip_part = ip_str.split(':').next().unwrap_or(ip_str); if IpAddr::from_str(ip_part).is_err() { @@ -70,20 +67,18 @@ impl ValidationUtils { true } - /// 验证 Forwarded 头部(RFC 7239) + /// Validates the format of an RFC 7239 Forwarded header value. pub fn validate_forwarded_header(header_value: &str) -> bool { if header_value.is_empty() { return false; } - // 简化的验证:检查基本格式 let parts: Vec<&str> = header_value.split(';').collect(); if parts.is_empty() { return false; } - // 检查每个部分是否包含等号 for part in parts { let part = part.trim(); if !part.contains('=') { @@ -94,7 +89,7 @@ impl ValidationUtils { true } - /// 验证 IP 地址是否在允许的范围内 + /// Checks if an IP address is within any of the specified CIDR ranges. pub fn validate_ip_in_range(ip: &IpAddr, cidr_ranges: &[String]) -> bool { for cidr in cidr_ranges { if let Ok(network) = ipnetwork::IpNetwork::from_str(cidr) { @@ -107,16 +102,14 @@ impl ValidationUtils { false } - /// 验证头部是否包含恶意内容 + /// Validates a header value for security (length and control characters). pub fn validate_header_value(value: &str) -> bool { - // 检查是否包含控制字符(除了水平制表符) for c in value.chars() { if c.is_control() && c != '\t' && c != '\n' && c != '\r' { return false; } } - // 检查长度限制(防止头部过大攻击) if value.len() > 8192 { return false; } @@ -124,53 +117,47 @@ impl ValidationUtils { true } - /// 验证整个头部映射 + /// Validates an entire HeaderMap for security. pub fn validate_headers(headers: &HeaderMap) -> bool { for (name, value) in headers { - // 检查头部名称 let name_str = name.as_str(); if name_str.len() > 256 { return false; } - // 检查头部值 if let Ok(value_str) = value.to_str() { if !Self::validate_header_value(value_str) { return false; } - } else { - // 无法转换为字符串,可能包含二进制数据 - if value.len() > 8192 { - return false; - } + } else if value.len() > 8192 { + return false; } } true } - /// 验证端口号 + /// Validates a port number. pub fn validate_port(port: u16) -> bool { - port > 0 && port <= 65535 + port > 0 } - /// 验证 CIDR 表示法 + /// Validates a CIDR notation string. pub fn validate_cidr(cidr: &str) -> bool { ipnetwork::IpNetwork::from_str(cidr).is_ok() } - /// 验证代理链长度 + /// Validates the length of a proxy chain. pub fn validate_proxy_chain_length(chain: &[IpAddr], max_length: usize) -> bool { chain.len() <= max_length } - /// 验证代理链是否连续 + /// Validates that a proxy chain does not contain duplicate adjacent IPs. pub fn validate_proxy_chain_continuity(chain: &[IpAddr]) -> bool { if chain.len() < 2 { return true; } - // 检查是否有重复的相邻 IP for i in 1..chain.len() { if chain[i] == chain[i - 1] { return false; @@ -180,24 +167,25 @@ impl ValidationUtils { true } - /// 验证字符串是否只包含安全字符 + /// Checks if a string contains only safe characters for use in URLs or headers. pub fn is_safe_string(s: &str) -> bool { - // 允许的字符:字母、数字、基本标点符号 - let safe_pattern = Regex::new(r"^[a-zA-Z0-9\-._~:/?#\[\]@!$&'()*+,;=]+$").unwrap(); - safe_pattern.is_match(s) + lazy_static! { + static ref SAFE_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9\-._~:/?#\[\]@!$&'()*+,;=]+$").unwrap(); + } + SAFE_REGEX.is_match(s) } - /// 验证速率限制参数 + /// Validates rate limiting parameters. pub fn validate_rate_limit_params(requests: u32, period_seconds: u64) -> bool { requests > 0 && requests <= 10000 && period_seconds > 0 && period_seconds <= 86400 } - /// 验证缓存参数 + /// Validates cache configuration parameters. pub fn validate_cache_params(capacity: usize, ttl_seconds: u64) -> bool { capacity > 0 && capacity <= 1000000 && ttl_seconds > 0 && ttl_seconds <= 86400 } - /// 脱敏敏感数据 + /// Redacts sensitive information from a string based on provided patterns. pub fn mask_sensitive_data(data: &str, sensitive_patterns: &[&str]) -> String { let mut result = data.to_string(); diff --git a/crates/trusted-proxies/test/integration/api_tests.rs b/crates/trusted-proxies/test/integration/api_tests.rs deleted file mode 100644 index 6cac9353..00000000 --- a/crates/trusted-proxies/test/integration/api_tests.rs +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2024 RustFS Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! API integration tests - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use axum::body::Body; - use axum::{extract::State, routing::get, Router}; - use serde_json::{json, Value}; - use tower::ServiceExt; - - use crate::config::{AppConfig, TrustedProxy, TrustedProxyConfig, ValidationMode}; - use crate::middleware::TrustedProxyLayer; - use crate::AppState; - - fn create_test_app_state() -> AppState { - let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; - - let proxy_config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]); - - let config = AppConfig::new( - proxy_config, - crate::config::CacheConfig::default(), - crate::config::MonitoringConfig::default(), - crate::config::CloudConfig::default(), - "127.0.0.1:3000".parse().unwrap(), - ); - - AppState { - config: Arc::new(config), - metrics: None, - } - } - - fn create_test_api_router() -> Router { - let state = create_test_app_state(); - let proxy_layer = TrustedProxyLayer::enabled(state.config.proxy.clone(), None); - - Router::new() - .route("/health", get(health_check)) - .route("/config", get(show_config)) - .with_state(state) - .layer(proxy_layer) - } - - async fn health_check() -> axum::response::Json { - axum::response::Json(json!({ - "status": "healthy", - "service": "trusted-proxy-test" - })) - } - - async fn show_config(State(state): State) -> axum::response::Json { - axum::response::Json(json!({ - "server": state.config.server_addr.to_string(), - "proxy": { - "trusted_networks": state.config.proxy.proxies.len(), - } - })) - } - - #[tokio::test] - async fn test_health_check_endpoint() { - let app = create_test_api_router(); - - let request = axum::http::Request::builder().uri("/health").body(Body::empty()).unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let json: Value = serde_json::from_slice(&body).unwrap(); - - assert_eq!(json["status"], "healthy"); - assert_eq!(json["service"], "trusted-proxy-test"); - } - - #[tokio::test] - async fn test_config_endpoint() { - let app = create_test_api_router(); - - let request = axum::http::Request::builder().uri("/config").body(Body::empty()).unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let json: Value = serde_json::from_slice(&body).unwrap(); - - assert_eq!(json["server"], "127.0.0.1:3000"); - assert_eq!(json["proxy"]["trusted_networks"], 1); - } - - #[tokio::test] - async fn test_proxy_headers_in_api() { - let state = create_test_app_state(); - let proxy_layer = TrustedProxyLayer::enabled(state.config.proxy.clone(), None); - - let app = Router::new() - .route( - "/client-test", - get(|req: axum::extract::Request| async move { - let client_info = req.extensions().get::(); - match client_info { - Some(info) => axum::response::Json(json!({ - "client_ip": info.real_ip.to_string(), - "trusted": info.is_from_trusted_proxy - })), - None => axum::response::Json(json!({ - "error": "No client info" - })), - } - }), - ) - .with_state(state) - .layer(proxy_layer); - - // 测试带代理头部的请求 - let request = axum::http::Request::builder() - .uri("/client-test") - .header("X-Forwarded-For", "203.0.113.195") - .body(Body::empty()) - .unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let json: Value = serde_json::from_slice(&body).unwrap(); - - // 由于请求来自 127.0.0.1(可信代理),应该解析 X-Forwarded-For - if json.get("client_ip").is_some() { - let client_ip = json["client_ip"].as_str().unwrap(); - // 可能是 203.0.113.195 或 127.0.0.1,取决于中间件如何配置 - assert!(client_ip == "203.0.113.195" || client_ip == "127.0.0.1"); - } - } - - #[tokio::test] - async fn test_missing_endpoint() { - let app = create_test_api_router(); - - let request = axum::http::Request::builder().uri("/not-found").body(Body::empty()).unwrap(); - - let response = app.oneshot(request).await.unwrap(); - - // 应该返回 404 - assert_eq!(response.status(), 404); - } - - #[tokio::test] - async fn test_request_without_proxy_layer() { - // 创建没有代理中间件的路由 - let app = Router::new().route("/simple", get(|| async { "OK" })); - - let request = axum::http::Request::builder().uri("/simple").body(Body::empty()).unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - assert_eq!(String::from_utf8(body.to_vec()).unwrap(), "OK"); - } -} diff --git a/crates/trusted-proxies/test/integration/cloud_tests.rs b/crates/trusted-proxies/test/integration/cloud_tests.rs deleted file mode 100644 index 4c8d69f8..00000000 --- a/crates/trusted-proxies/test/integration/cloud_tests.rs +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright 2024 RustFS Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Cloud metadata integration tests - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use wiremock::matchers::{method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; - - use crate::cloud::detector::CloudDetector; - use crate::cloud::metadata::{AwsMetadataFetcher, AzureMetadataFetcher, GcpMetadataFetcher}; - use crate::cloud::ranges::{CloudflareIpRanges, GoogleCloudIpRanges}; - - #[tokio::test] - async fn test_cloud_detector_disabled() { - let detector = CloudDetector::new( - false, // 禁用 - std::time::Duration::from_secs(1), - None, - ); - - let provider = detector.detect_provider(); - assert!(provider.is_none()); - - let ranges = detector.fetch_trusted_ranges().await; - assert!(ranges.is_ok()); - assert!(ranges.unwrap().is_empty()); - } - - #[tokio::test] - async fn test_cloud_detector_forced_provider() { - let detector = CloudDetector::new(true, std::time::Duration::from_secs(1), Some("aws".to_string())); - - let provider = detector.detect_provider(); - assert!(provider.is_some()); - assert_eq!(provider.unwrap().name(), "aws"); - } - - #[tokio::test] - async fn test_aws_metadata_fetcher() { - let fetcher = AwsMetadataFetcher::new(); - - // 测试提供者名称 - assert_eq!(fetcher.provider_name(), "aws"); - - // 由于不在 AWS 环境中,这些调用应该失败或返回默认值 - let network_result = fetcher.fetch_network_cidrs().await; - // 可能返回默认范围或错误 - assert!(network_result.is_ok()); - - let public_result = fetcher.fetch_public_ip_ranges().await; - // 可能从 API 获取或返回空列表 - assert!(public_result.is_ok()); - } - - #[tokio::test] - async fn test_azure_metadata_fetcher() { - let fetcher = AzureMetadataFetcher::new(); - - // 测试提供者名称 - assert_eq!(fetcher.provider_name(), "azure"); - - // 由于不在 Azure 环境中,这些调用应该返回默认值 - let network_result = fetcher.fetch_network_cidrs().await; - assert!(network_result.is_ok()); - - let public_result = fetcher.fetch_public_ip_ranges().await; - assert!(public_result.is_ok()); - } - - #[tokio::test] - async fn test_gcp_metadata_fetcher() { - let fetcher = GcpMetadataFetcher::new(); - - // 测试提供者名称 - assert_eq!(fetcher.provider_name(), "gcp"); - - // 由于不在 GCP 环境中,这些调用应该返回默认值 - let network_result = fetcher.fetch_network_cidrs().await; - assert!(network_result.is_ok()); - - let public_result = fetcher.fetch_public_ip_ranges().await; - assert!(public_result.is_ok()); - } - - #[tokio::test] - async fn test_cloudflare_ip_ranges_static() { - let ranges = CloudflareIpRanges::fetch().await; - - assert!(ranges.is_ok()); - - let networks = ranges.unwrap(); - assert!(!networks.is_empty()); - - // 检查是否包含预期的范围 - let has_ipv4 = networks - .iter() - .any(|n| n.to_string().contains("103.21.244.0/22") || n.to_string().contains("198.41.128.0/17")); - - let has_ipv6 = networks - .iter() - .any(|n| n.to_string().contains("2400:cb00::/32") || n.to_string().contains("2606:4700::/32")); - - assert!(has_ipv4 || has_ipv6); - } - - #[tokio::test] - async fn test_google_cloud_ip_ranges_api_mock() { - // 创建模拟服务器 - let mock_server = MockServer::start().await; - - // 模拟 Google IP 范围 API 响应 - let mock_response = r#" - { - "prefixes": [ - {"ipv4Prefix": "8.34.208.0/20"}, - {"ipv4Prefix": "8.35.192.0/20"}, - {"ipv6Prefix": "2001:4860::/32"} - ] - } - "#; - - Mock::given(method("GET")) - .and(path("/ipranges/cloud.json")) - .respond_with(ResponseTemplate::new(200).set_body_string(mock_response)) - .mount(&mock_server) - .await; - - // 创建自定义客户端指向模拟服务器 - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(2)) - .build() - .unwrap(); - - let url = format!("{}/ipranges/cloud.json", mock_server.uri()); - - let response = client.get(&url).send().await.unwrap(); - assert_eq!(response.status(), 200); - - let body = response.text().await.unwrap(); - assert!(body.contains("8.34.208.0/20")); - assert!(body.contains("2001:4860::/32")); - } - - #[tokio::test] - async fn test_cloud_detector_try_all_providers() { - let detector = CloudDetector::new(true, std::time::Duration::from_secs(2), None); - - // 在测试环境中,所有提供者都应该失败或返回空列表 - let result = detector.try_all_providers().await; - - // 应该成功返回(即使是空列表) - assert!(result.is_ok()); - } - - #[test] - fn test_ip_network_parsing() { - // 测试 CIDR 解析 - let cidr = ipnetwork::IpNetwork::from_str("192.168.1.0/24"); - assert!(cidr.is_ok()); - - let network = cidr.unwrap(); - assert_eq!(network.prefix(), 24); - - // 测试 IP 包含检查 - let ip: std::net::IpAddr = "192.168.1.100".parse().unwrap(); - assert!(network.contains(ip)); - - let ip_outside: std::net::IpAddr = "192.168.2.100".parse().unwrap(); - assert!(!network.contains(ip_outside)); - - // 测试 IPv6 CIDR - let ipv6_cidr = ipnetwork::IpNetwork::from_str("2001:db8::/32"); - assert!(ipv6_cidr.is_ok()); - - let ipv6_network = ipv6_cidr.unwrap(); - assert_eq!(ipv6_network.prefix(), 32); - } -} diff --git a/crates/trusted-proxies/test/integration/proxy_tests.rs b/crates/trusted-proxies/test/integration/proxy_tests.rs deleted file mode 100644 index 381e2a10..00000000 --- a/crates/trusted-proxies/test/integration/proxy_tests.rs +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2024 RustFS Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Proxy system integration tests - -#[cfg(test)] -mod tests { - use std::error::Request; - use axum::body::Body; - use axum::{extract::Request, routing::get, Router}; - use tower::ServiceExt; - - use crate::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode}; - use crate::middleware::{ClientInfo, TrustedProxyLayer}; - - fn create_test_router() -> Router { - let proxies = vec![ - TrustedProxy::Single("127.0.0.1".parse().unwrap()), - TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), - ]; - - let config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]); - - let proxy_layer = TrustedProxyLayer::enabled(config, None); - - Router::new() - .route( - "/test", - get(|req: Request| async move { - let client_info = req.extensions().get::(); - match client_info { - Some(info) => { - format!("IP: {}, Trusted: {}, Hops: {}", info.real_ip, info.is_from_trusted_proxy, info.proxy_hops) - } - None => "No client info".to_string(), - } - }), - ) - .layer(proxy_layer) - } - - #[tokio::test] - async fn test_direct_connection() { - let app = create_test_router(); - - // 模拟直接连接(无代理头部) - let request = axum::http::Request::builder().uri("/test").body(Body::empty()).unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let body_str = String::from_utf8(body.to_vec()).unwrap(); - - // 应该显示直接连接的 IP(在测试环境中可能是 0.0.0.0) - assert!(body_str.contains("IP:")); - } - - #[tokio::test] - async fn test_trusted_proxy_with_xff() { - let app = create_test_router(); - - // 模拟来自可信代理的请求 - let request = axum::http::Request::builder() - .uri("/test") - .header("X-Forwarded-For", "203.0.113.195, 10.0.1.100") - .body(Body::empty()) - .unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let body_str = String::from_utf8(body.to_vec()).unwrap(); - - // 应该显示客户端 IP (203.0.113.195) - assert!(body_str.contains("203.0.113.195")); - assert!(body_str.contains("Trusted: true")); - } - - #[tokio::test] - async fn test_untrusted_proxy_with_xff() { - let app = create_test_router(); - - // 模拟来自不可信代理的请求 - let request = axum::http::Request::builder() - .uri("/test") - .header("X-Forwarded-For", "203.0.113.195") - .body(Body::empty()) - .unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let body_str = String::from_utf8(body.to_vec()).unwrap(); - - // 由于请求不是来自可信代理,X-Forwarded-For 应该被忽略 - // 应该显示直接连接的 IP - assert!(!body_str.contains("203.0.113.195")); - } - - #[tokio::test] - async fn test_proxy_chain_too_long() { - let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; - - let config = TrustedProxyConfig::new( - proxies, - ValidationMode::Strict, - true, - 3, // 最大 3 跳 - true, - vec![], - ); - - let proxy_layer = TrustedProxyLayer::enabled(config, None); - - let app = Router::new().route("/test", get(|| async { "OK" })).layer(proxy_layer); - - // 模拟超长代理链 - let xff_value = (0..5).map(|i| format!("10.0.{}.1", i)).collect::>().join(", "); - - let request = axum::http::Request::builder() - .uri("/test") - .header("X-Forwarded-For", xff_value) - .body(Body::empty()) - .unwrap(); - - let response = app.oneshot(request).await.unwrap(); - - // 由于代理链太长,验证应该失败 - // 注意:中间件可能会降级处理,而不是直接拒绝 - assert_eq!(response.status(), 200); - } - - #[tokio::test] - async fn test_rfc7239_forwarded_header() { - let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; - - let config = TrustedProxyConfig::new( - proxies, - ValidationMode::HopByHop, - true, // 启用 RFC 7239 - 10, - true, - vec![], - ); - - let proxy_layer = TrustedProxyLayer::enabled(config, None); - - let app = Router::new() - .route( - "/test", - get(|req: Request| async move { - let client_info = req.extensions().get::().unwrap(); - format!("IP: {}", client_info.real_ip) - }), - ) - .layer(proxy_layer); - - // 模拟使用 RFC 7239 Forwarded 头部的请求 - let request = axum::http::Request::builder() - .uri("/test") - .header("Forwarded", r#"for=192.0.2.60;proto=https;by=203.0.113.43"#) - .body(Body::empty()) - .unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), 200); - - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let body_str = String::from_utf8(body.to_vec()).unwrap(); - - // 应该解析 RFC 7239 头部 - assert!(body_str.contains("192.0.2.60")); - } -} diff --git a/crates/trusted-proxies/test/unit/config_tests.rs b/crates/trusted-proxies/test/unit/config_tests.rs deleted file mode 100644 index 61ce51d7..00000000 --- a/crates/trusted-proxies/test/unit/config_tests.rs +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2024 RustFS Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Configuration module unit tests - -#[cfg(test)] -mod tests { - use std::net::IpAddr; - use std::str::FromStr; - - use crate::config::env::{DEFAULT_TRUSTED_PROXIES, ENV_TRUSTED_PROXIES}; - use crate::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode}; - - #[test] - fn test_config_loader_default() { - // 清理环境变量 - std::env::remove_var(ENV_TRUSTED_PROXIES); - - let config = ConfigLoader::from_env_or_default(); - - // 验证默认值 - assert_eq!(config.server_addr.port(), 3000); - assert!(!config.proxy.proxies.is_empty()); - assert_eq!(config.proxy.validation_mode, ValidationMode::HopByHop); - assert!(config.proxy.enable_rfc7239); - assert_eq!(config.proxy.max_hops, 10); - } - - #[test] - fn test_config_loader_env_vars() { - // 设置环境变量 - std::env::set_var(ENV_TRUSTED_PROXIES, "192.168.1.0/24,10.0.0.0/8"); - std::env::set_var("TRUSTED_PROXY_VALIDATION_MODE", "strict"); - std::env::set_var("TRUSTED_PROXY_MAX_HOPS", "5"); - std::env::set_var("SERVER_PORT", "8080"); - - let config = ConfigLoader::from_env(); - - if let Ok(config) = config { - assert_eq!(config.server_addr.port(), 8080); - assert_eq!(config.proxy.validation_mode, ValidationMode::Strict); - assert_eq!(config.proxy.max_hops, 5); - - // 清理环境变量 - std::env::remove_var(ENV_TRUSTED_PROXIES); - std::env::remove_var("TRUSTED_PROXY_VALIDATION_MODE"); - std::env::remove_var("TRUSTED_PROXY_MAX_HOPS"); - std::env::remove_var("SERVER_PORT"); - } else { - panic!("Failed to load config from env"); - } - } - - #[test] - fn test_trusted_proxy_config() { - let proxies = vec![ - TrustedProxy::Single("192.168.1.1".parse().unwrap()), - TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), - ]; - - let config = TrustedProxyConfig::new(proxies.clone(), ValidationMode::Strict, true, 10, true, vec![]); - - assert_eq!(config.proxies.len(), 2); - assert_eq!(config.validation_mode, ValidationMode::Strict); - assert!(config.enable_rfc7239); - assert_eq!(config.max_hops, 10); - assert!(config.enable_chain_continuity_check); - - // 测试 IP 检查 - let test_ip: IpAddr = "192.168.1.1".parse().unwrap(); - let test_socket_addr = std::net::SocketAddr::new(test_ip, 8080); - - assert!(config.is_trusted(&test_socket_addr)); - - let test_ip2: IpAddr = "10.0.1.1".parse().unwrap(); - let test_socket_addr2 = std::net::SocketAddr::new(test_ip2, 8080); - - assert!(config.is_trusted(&test_socket_addr2)); - } - - #[test] - fn test_validation_mode_from_str() { - assert_eq!(ValidationMode::from_str("lenient").unwrap(), ValidationMode::Lenient); - assert_eq!(ValidationMode::from_str("strict").unwrap(), ValidationMode::Strict); - assert_eq!(ValidationMode::from_str("hop_by_hop").unwrap(), ValidationMode::HopByHop); - - // 测试无效值 - assert!(ValidationMode::from_str("invalid").is_err()); - } - - #[test] - fn test_trusted_proxy_contains() { - // 测试单个 IP - let single_proxy = TrustedProxy::Single("192.168.1.1".parse().unwrap()); - let test_ip: IpAddr = "192.168.1.1".parse().unwrap(); - let test_ip2: IpAddr = "192.168.1.2".parse().unwrap(); - - assert!(single_proxy.contains(&test_ip)); - assert!(!single_proxy.contains(&test_ip2)); - - // 测试 CIDR 范围 - let cidr_proxy = TrustedProxy::Cidr("192.168.1.0/24".parse().unwrap()); - - assert!(cidr_proxy.contains(&test_ip)); - assert!(cidr_proxy.contains(&test_ip2)); - - let test_ip3: IpAddr = "192.168.2.1".parse().unwrap(); - assert!(!cidr_proxy.contains(&test_ip3)); - } - - #[test] - fn test_private_network_check() { - let config = TrustedProxyConfig::new( - Vec::new(), - ValidationMode::Lenient, - true, - 10, - true, - vec!["10.0.0.0/8".parse().unwrap(), "192.168.0.0/16".parse().unwrap()], - ); - - let private_ip: IpAddr = "10.0.1.1".parse().unwrap(); - let private_ip2: IpAddr = "192.168.1.1".parse().unwrap(); - let public_ip: IpAddr = "8.8.8.8".parse().unwrap(); - - assert!(config.is_private_network(&private_ip)); - assert!(config.is_private_network(&private_ip2)); - assert!(!config.is_private_network(&public_ip)); - } - - #[test] - fn test_parse_ip_list_from_env() { - use crate::config::env::parse_ip_list_from_env; - - // 测试有效的 IP 列表 - std::env::set_var("TEST_IP_LIST", "10.0.0.0/8,192.168.1.0/24"); - - let result = parse_ip_list_from_env("TEST_IP_LIST", ""); - assert!(result.is_ok()); - - let networks = result.unwrap(); - assert_eq!(networks.len(), 2); - - // 测试空值 - std::env::set_var("TEST_IP_LIST_EMPTY", ""); - let result = parse_ip_list_from_env("TEST_IP_LIST_EMPTY", ""); - assert!(result.is_ok()); - assert!(result.unwrap().is_empty()); - - // 测试无效值 - std::env::set_var("TEST_IP_LIST_INVALID", "invalid,10.0.0.0/8"); - let result = parse_ip_list_from_env("TEST_IP_LIST_INVALID", ""); - assert!(result.is_ok()); // 无效项会被跳过 - - // 清理环境变量 - std::env::remove_var("TEST_IP_LIST"); - std::env::remove_var("TEST_IP_LIST_EMPTY"); - std::env::remove_var("TEST_IP_LIST_INVALID"); - } - - #[test] - fn test_default_values() { - use crate::config::env::{ - DEFAULT_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_MAX_HOPS, DEFAULT_PROXY_VALIDATION_MODE, DEFAULT_TRUSTED_PROXIES, - }; - - assert_eq!(DEFAULT_TRUSTED_PROXIES, "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"); - assert_eq!(DEFAULT_PROXY_VALIDATION_MODE, "hop_by_hop"); - assert_eq!(DEFAULT_PROXY_MAX_HOPS, 10); - assert!(DEFAULT_PROXY_ENABLE_RFC7239); - } -} diff --git a/crates/trusted-proxies/test/unit/ip_tests.rs b/crates/trusted-proxies/test/unit/ip_tests.rs deleted file mode 100644 index 735fbd8a..00000000 --- a/crates/trusted-proxies/test/unit/ip_tests.rs +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2024 RustFS Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! IP utility tests - -#[cfg(test)] -mod tests { - use std::net::IpAddr; - use std::str::FromStr; - - use crate::utils::ip::IpUtils; - - #[test] - fn test_is_valid_ip_address() { - // 测试有效 IP - let valid_ip: IpAddr = "192.168.1.1".parse().unwrap(); - assert!(IpUtils::is_valid_ip_address(&valid_ip)); - - // 测试未指定地址 - let unspecified_ip: IpAddr = "0.0.0.0".parse().unwrap(); - assert!(!IpUtils::is_valid_ip_address(&unspecified_ip)); - - // 测试多播地址 - let multicast_ip: IpAddr = "224.0.0.1".parse().unwrap(); - assert!(!IpUtils::is_valid_ip_address(&multicast_ip)); - - // 测试 IPv6 - let valid_ipv6: IpAddr = "2001:db8::1".parse().unwrap(); - assert!(IpUtils::is_valid_ip_address(&valid_ipv6)); - - let unspecified_ipv6: IpAddr = "::".parse().unwrap(); - assert!(!IpUtils::is_valid_ip_address(&unspecified_ipv6)); - } - - #[test] - fn test_is_reserved_ip() { - // 测试私有地址 - let private_ip: IpAddr = "10.0.0.1".parse().unwrap(); - assert!(IpUtils::is_reserved_ip(&private_ip)); - - // 测试回环地址 - let loopback_ip: IpAddr = "127.0.0.1".parse().unwrap(); - assert!(IpUtils::is_reserved_ip(&loopback_ip)); - - // 测试链路本地地址 - let link_local_ip: IpAddr = "169.254.0.1".parse().unwrap(); - assert!(IpUtils::is_reserved_ip(&link_local_ip)); - - // 测试文档地址 - let documentation_ip: IpAddr = "192.0.2.1".parse().unwrap(); - assert!(IpUtils::is_reserved_ip(&documentation_ip)); - - // 测试公网地址 - let public_ip: IpAddr = "8.8.8.8".parse().unwrap(); - assert!(!IpUtils::is_reserved_ip(&public_ip)); - } - - #[test] - fn test_is_private_ip() { - // 测试 10.0.0.0/8 - assert!(IpUtils::is_private_ip(&"10.0.0.1".parse().unwrap())); - assert!(IpUtils::is_private_ip(&"10.255.255.254".parse().unwrap())); - - // 测试 172.16.0.0/12 - assert!(IpUtils::is_private_ip(&"172.16.0.1".parse().unwrap())); - assert!(IpUtils::is_private_ip(&"172.31.255.254".parse().unwrap())); - assert!(!IpUtils::is_private_ip(&"172.15.0.1".parse().unwrap())); - assert!(!IpUtils::is_private_ip(&"172.32.0.1".parse().unwrap())); - - // 测试 192.168.0.0/16 - assert!(IpUtils::is_private_ip(&"192.168.0.1".parse().unwrap())); - assert!(IpUtils::is_private_ip(&"192.168.255.254".parse().unwrap())); - - // 测试公网地址 - assert!(!IpUtils::is_private_ip(&"8.8.8.8".parse().unwrap())); - assert!(!IpUtils::is_private_ip(&"203.0.113.1".parse().unwrap())); - } - - #[test] - fn test_is_loopback_ip() { - // IPv4 回环地址 - assert!(IpUtils::is_loopback_ip(&"127.0.0.1".parse().unwrap())); - assert!(IpUtils::is_loopback_ip(&"127.255.255.254".parse().unwrap())); - - // IPv6 回环地址 - assert!(IpUtils::is_loopback_ip(&"::1".parse().unwrap())); - - // 非回环地址 - assert!(!IpUtils::is_loopback_ip(&"192.168.1.1".parse().unwrap())); - assert!(!IpUtils::is_loopback_ip(&"2001:db8::1".parse().unwrap())); - } - - #[test] - fn test_is_link_local_ip() { - // IPv4 链路本地地址 - assert!(IpUtils::is_link_local_ip(&"169.254.0.1".parse().unwrap())); - assert!(IpUtils::is_link_local_ip(&"169.254.255.254".parse().unwrap())); - - // IPv6 链路本地地址 - assert!(IpUtils::is_link_local_ip(&"fe80::1".parse().unwrap())); - assert!(IpUtils::is_link_local_ip(&"fe80::abcd:1234:5678:9abc".parse().unwrap())); - - // 非链路本地地址 - assert!(!IpUtils::is_link_local_ip(&"192.168.1.1".parse().unwrap())); - assert!(!IpUtils::is_link_local_ip(&"2001:db8::1".parse().unwrap())); - } - - #[test] - fn test_is_documentation_ip() { - // IPv4 文档地址 - assert!(IpUtils::is_documentation_ip(&"192.0.2.1".parse().unwrap())); - assert!(IpUtils::is_documentation_ip(&"198.51.100.1".parse().unwrap())); - assert!(IpUtils::is_documentation_ip(&"203.0.113.1".parse().unwrap())); - - // IPv6 文档地址 - assert!(IpUtils::is_documentation_ip(&"2001:db8::1".parse().unwrap())); - - // 非文档地址 - assert!(!IpUtils::is_documentation_ip(&"8.8.8.8".parse().unwrap())); - assert!(!IpUtils::is_documentation_ip(&"2001:4860::1".parse().unwrap())); - } - - #[test] - fn test_parse_ip_or_cidr() { - // 测试单个 IP - let result = IpUtils::parse_ip_or_cidr("192.168.1.1"); - assert!(result.is_ok()); - - // 测试 CIDR - let result = IpUtils::parse_ip_or_cidr("192.168.1.0/24"); - assert!(result.is_ok()); - - // 测试 IPv6 - let result = IpUtils::parse_ip_or_cidr("2001:db8::1"); - assert!(result.is_ok()); - - let result = IpUtils::parse_ip_or_cidr("2001:db8::/32"); - assert!(result.is_ok()); - - // 测试无效输入 - let result = IpUtils::parse_ip_or_cidr("invalid"); - assert!(result.is_err()); - } - - #[test] - fn test_parse_ip_list() { - // 测试有效的 IP 列表 - let result = IpUtils::parse_ip_list("192.168.1.1, 10.0.0.1, 8.8.8.8"); - assert!(result.is_ok()); - - let ips = result.unwrap(); - assert_eq!(ips.len(), 3); - assert_eq!(ips[0], IpAddr::from_str("192.168.1.1").unwrap()); - assert_eq!(ips[1], IpAddr::from_str("10.0.0.1").unwrap()); - assert_eq!(ips[2], IpAddr::from_str("8.8.8.8").unwrap()); - - // 测试带空格的 IP 列表 - let result = IpUtils::parse_ip_list("192.168.1.1,10.0.0.1"); - assert!(result.is_ok()); - assert_eq!(result.unwrap().len(), 2); - - // 测试空列表 - let result = IpUtils::parse_ip_list(""); - assert!(result.is_ok()); - assert!(result.unwrap().is_empty()); - - // 测试无效 IP - let result = IpUtils::parse_ip_list("192.168.1.1, invalid"); - assert!(result.is_err()); - } - - #[test] - fn test_parse_network_list() { - // 测试有效的网络列表 - let result = IpUtils::parse_network_list("192.168.1.0/24, 10.0.0.0/8"); - assert!(result.is_ok()); - - let networks = result.unwrap(); - assert_eq!(networks.len(), 2); - - // 测试单个 IP(会被当作/32 或/128 网络) - let result = IpUtils::parse_network_list("192.168.1.1"); - assert!(result.is_ok()); - assert_eq!(result.unwrap().len(), 1); - - // 测试无效网络 - let result = IpUtils::parse_network_list("192.168.1.0/24, invalid"); - assert!(result.is_err()); - } - - #[test] - fn test_ip_in_networks() { - let networks = vec!["10.0.0.0/8".parse().unwrap(), "192.168.1.0/24".parse().unwrap()]; - - let ip_in_network: IpAddr = "10.0.1.1".parse().unwrap(); - let ip_in_network2: IpAddr = "192.168.1.100".parse().unwrap(); - let ip_not_in_network: IpAddr = "8.8.8.8".parse().unwrap(); - - assert!(IpUtils::ip_in_networks(&ip_in_network, &networks)); - assert!(IpUtils::ip_in_networks(&ip_in_network2, &networks)); - assert!(!IpUtils::ip_in_networks(&ip_not_in_network, &networks)); - } - - #[test] - fn test_get_ip_type() { - assert_eq!(IpUtils::get_ip_type(&"10.0.0.1".parse().unwrap()), "private"); - assert_eq!(IpUtils::get_ip_type(&"127.0.0.1".parse().unwrap()), "loopback"); - assert_eq!(IpUtils::get_ip_type(&"169.254.0.1".parse().unwrap()), "link_local"); - assert_eq!(IpUtils::get_ip_type(&"192.0.2.1".parse().unwrap()), "documentation"); - assert_eq!(IpUtils::get_ip_type(&"224.0.0.1".parse().unwrap()), "reserved"); - assert_eq!(IpUtils::get_ip_type(&"8.8.8.8".parse().unwrap()), "public"); - } - - #[test] - fn test_canonical_ip() { - // 测试 IPv4 - let ipv4: IpAddr = "192.168.001.001".parse().unwrap(); - assert_eq!(IpUtils::canonical_ip(&ipv4), "192.168.1.1"); - - // 测试 IPv6 压缩 - let ipv6_full: IpAddr = "2001:0db8:0000:0000:0000:0000:0000:0001".parse().unwrap(); - let ipv6_compressed: IpAddr = "2001:db8::1".parse().unwrap(); - - assert_eq!(IpUtils::canonical_ip(&ipv6_full), "2001:db8::1"); - assert_eq!(IpUtils::canonical_ip(&ipv6_compressed), "2001:db8::1"); - - // 测试包含多个零段的 IPv6 - let ipv6_multi_zero: IpAddr = "2001:0db8:0000:0000:abcd:0000:0000:1234".parse().unwrap(); - assert_eq!(IpUtils::canonical_ip(&ipv6_multi_zero), "2001:db8::abcd:0:0:1234"); - } -} diff --git a/crates/trusted-proxies/test/unit/validation_tests.rs b/crates/trusted-proxies/test/unit/validation_tests.rs deleted file mode 100644 index 44954b93..00000000 --- a/crates/trusted-proxies/test/unit/validation_tests.rs +++ /dev/null @@ -1,637 +0,0 @@ -// Copyright 2024 RustFS Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Validation utility unit tests - -#[cfg(test)] -mod tests { - use http::HeaderMap; - use std::net::IpAddr; - - use crate::utils::ip::IpUtils; - use crate::utils::validation::ValidationUtils; - - /// 测试电子邮件验证 - #[test] - fn test_email_validation() { - // 有效的电子邮件地址 - assert!(ValidationUtils::is_valid_email("user@example.com")); - assert!(ValidationUtils::is_valid_email("first.last@example.co.uk")); - assert!(ValidationUtils::is_valid_email("user123@example.org")); - assert!(ValidationUtils::is_valid_email("user+tag@example.com")); - assert!(ValidationUtils::is_valid_email("user_name@example-domain.com")); - - // 无效的电子邮件地址 - assert!(!ValidationUtils::is_valid_email("")); - assert!(!ValidationUtils::is_valid_email("invalid-email")); - assert!(!ValidationUtils::is_valid_email("user@")); - assert!(!ValidationUtils::is_valid_email("@example.com")); - assert!(!ValidationUtils::is_valid_email("user@.com")); - assert!(!ValidationUtils::is_valid_email("user@example.")); - assert!(!ValidationUtils::is_valid_email("user@example..com")); - assert!(!ValidationUtils::is_valid_email("user@example_com")); - assert!(!ValidationUtils::is_valid_email("user@[127.0.0.1]")); - assert!(!ValidationUtils::is_valid_email("user name@example.com")); - assert!(!ValidationUtils::is_valid_email("user@exa mple.com")); - assert!(!ValidationUtils::is_valid_email("user@example.c")); - } - - /// 测试 URL 验证 - #[test] - fn test_url_validation() { - // 有效的 URL - assert!(ValidationUtils::is_valid_url("https://example.com")); - assert!(ValidationUtils::is_valid_url("http://example.com")); - assert!(ValidationUtils::is_valid_url("example.com")); - assert!(ValidationUtils::is_valid_url("sub.example.com")); - assert!(ValidationUtils::is_valid_url("example.co.uk")); - assert!(ValidationUtils::is_valid_url("example.com/path")); - assert!(ValidationUtils::is_valid_url("example.com/path/to/resource")); - assert!(ValidationUtils::is_valid_url("example.com/?query=param")); - assert!(ValidationUtils::is_valid_url("sub-domain.example-domain.com")); - - // 无效的 URL - assert!(!ValidationUtils::is_valid_url("")); - assert!(!ValidationUtils::is_valid_url("invalid")); - assert!(!ValidationUtils::is_valid_url("example")); - assert!(!ValidationUtils::is_valid_url("example.")); - assert!(!ValidationUtils::is_valid_url(".com")); - assert!(!ValidationUtils::is_valid_url("http://")); - assert!(!ValidationUtils::is_valid_url("https://")); - assert!(!ValidationUtils::is_valid_url("://example.com")); - assert!(!ValidationUtils::is_valid_url("example..com")); - assert!(!ValidationUtils::is_valid_url("-example.com")); - assert!(!ValidationUtils::is_valid_url("example-.com")); - assert!(!ValidationUtils::is_valid_url("example_com")); - } - - /// 测试 X-Forwarded-For 头部验证 - #[test] - fn test_x_forwarded_for_validation() { - // 有效的 X-Forwarded-For 头部 - assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195")); - assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195, 198.51.100.1")); - assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195,198.51.100.1,10.0.1.100")); - assert!(ValidationUtils::validate_x_forwarded_for("2001:db8::1")); - assert!(ValidationUtils::validate_x_forwarded_for("2001:db8::1, 2001:db8::2")); - assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195:8080, 198.51.100.1:443")); // 带端口 - - // 无效的 X-Forwarded-For 头部 - assert!(!ValidationUtils::validate_x_forwarded_for("")); // 空字符串 - assert!(!ValidationUtils::validate_x_forwarded_for(" ")); // 只有空格 - assert!(!ValidationUtils::validate_x_forwarded_for("invalid")); // 无效 IP - assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195, invalid")); // 部分无效 - assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195, ")); // 尾部逗号加空格 - assert!(!ValidationUtils::validate_x_forwarded_for(",203.0.113.195")); // 开头逗号 - assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195,,198.51.100.1")); // 连续逗号 - assert!(!ValidationUtils::validate_x_forwarded_for("256.256.256.256")); // 超出范围的 IP - } - - /// 测试 Forwarded 头部验证 (RFC 7239) - #[test] - fn test_forwarded_header_validation() { - // 有效的 Forwarded 头部 - assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60")); - assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60;proto=http")); - assert!(ValidationUtils::validate_forwarded_header("for=\"[2001:db8:cafe::17]\";proto=https")); - assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.43, for=198.51.100.17")); - assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60;proto=http;by=203.0.113.43")); - assert!(ValidationUtils::validate_forwarded_header( - "by=203.0.113.43;for=192.0.2.60;host=example.com;proto=https" - )); - - // 无效的 Forwarded 头部 - assert!(!ValidationUtils::validate_forwarded_header("")); // 空字符串 - assert!(!ValidationUtils::validate_forwarded_header(" ")); // 只有空格 - assert!(!ValidationUtils::validate_forwarded_header("invalid")); // 无效格式 - assert!(!ValidationUtils::validate_forwarded_header("for=192.0.2.60 proto=http")); // 缺少分号 - assert!(!ValidationUtils::validate_forwarded_header("for;192.0.2.60")); // 缺少等号 - assert!(!ValidationUtils::validate_forwarded_header("=192.0.2.60")); // 缺少键 - assert!(!ValidationUtils::validate_forwarded_header("for=")); // 缺少值 - } - - /// 测试 IP 范围验证 - #[test] - fn test_ip_in_range_validation() { - let cidr_ranges = vec![ - "10.0.0.0/8".to_string(), - "192.168.0.0/16".to_string(), - "172.16.0.0/12".to_string(), - "2001:db8::/32".to_string(), - ]; - - // IP 在范围内 - let ip_in_range: IpAddr = "10.0.1.1".parse().unwrap(); - assert!(ValidationUtils::validate_ip_in_range(&ip_in_range, &cidr_ranges)); - - let ip_in_range2: IpAddr = "192.168.1.100".parse().unwrap(); - assert!(ValidationUtils::validate_ip_in_range(&ip_in_range2, &cidr_ranges)); - - let ip_in_range3: IpAddr = "172.16.0.1".parse().unwrap(); - assert!(ValidationUtils::validate_ip_in_range(&ip_in_range3, &cidr_ranges)); - - let ipv6_in_range: IpAddr = "2001:db8::1".parse().unwrap(); - assert!(ValidationUtils::validate_ip_in_range(&ipv6_in_range, &cidr_ranges)); - - // IP 不在范围内 - let ip_not_in_range: IpAddr = "8.8.8.8".parse().unwrap(); - assert!(!ValidationUtils::validate_ip_in_range(&ip_not_in_range, &cidr_ranges)); - - let ip_not_in_range2: IpAddr = "203.0.113.1".parse().unwrap(); - assert!(!ValidationUtils::validate_ip_in_range(&ip_not_in_range2, &cidr_ranges)); - - let ipv6_not_in_range: IpAddr = "2001:4860::1".parse().unwrap(); - assert!(!ValidationUtils::validate_ip_in_range(&ipv6_not_in_range, &cidr_ranges)); - - // 空范围列表 - assert!(!ValidationUtils::validate_ip_in_range(&ip_in_range, &Vec::new())); - - // 无效的 CIDR 范围(应该被忽略) - let invalid_ranges = vec![ - "invalid".to_string(), - "10.0.0.0/8".to_string(), // 这个有效 - ]; - - let test_ip: IpAddr = "10.0.1.1".parse().unwrap(); - // 即使有无效范围,只要有一个有效范围包含 IP,就应该返回 true - assert!(ValidationUtils::validate_ip_in_range(&test_ip, &invalid_ranges)); - } - - /// 测试头部值验证 - #[test] - fn test_header_value_validation() { - // 有效的头部值 - assert!(ValidationUtils::validate_header_value("text/plain")); - assert!(ValidationUtils::validate_header_value("application/json; charset=utf-8")); - assert!(ValidationUtils::validate_header_value("Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9")); - assert!(ValidationUtils::validate_header_value("127.0.0.1:8080")); - assert!(ValidationUtils::validate_header_value("")); // 空字符串是有效的 - assert!(ValidationUtils::validate_header_value("normal text with spaces")); // 普通文本 - assert!(ValidationUtils::validate_header_value("value\twith\ttabs")); // 包含制表符 - - // 长但有效的头部值(边界情况) - let long_value = "a".repeat(8192); - assert!(ValidationUtils::validate_header_value(&long_value)); - - // 无效的头部值 - let too_long_value = "a".repeat(8193); // 超过长度限制 - assert!(!ValidationUtils::validate_header_value(&too_long_value)); - - // 包含控制字符(除了制表符、换行符、回车符) - assert!(!ValidationUtils::validate_header_value("value\x00with_null")); // 空字符 - assert!(!ValidationUtils::validate_header_value("value\x01with_soh")); // 标题开始 - assert!(!ValidationUtils::validate_header_value("value\x1fwith_us")); // 单元分隔符 - - // 换行符和回车符是允许的(在某些上下文中) - assert!(ValidationUtils::validate_header_value("line1\nline2")); - assert!(ValidationUtils::validate_header_value("line1\r\nline2")); - } - - /// 测试头部映射验证 - #[test] - fn test_headers_validation() { - let mut valid_headers = HeaderMap::new(); - valid_headers.insert("Content-Type", "application/json".parse().unwrap()); - valid_headers.insert("Authorization", "Bearer token123".parse().unwrap()); - valid_headers.insert("X-Forwarded-For", "203.0.113.195".parse().unwrap()); - - assert!(ValidationUtils::validate_headers(&valid_headers)); - - // 测试头部名称过长 - let mut invalid_headers = HeaderMap::new(); - let long_name = "X-".to_string() + &"A".repeat(300); // 超过 256 字符 - invalid_headers.insert(long_name, "value".parse().unwrap()); - - assert!(!ValidationUtils::validate_headers(&invalid_headers)); - - // 测试头部值过长 - let mut invalid_headers2 = HeaderMap::new(); - let long_value = "A".repeat(8193); // 超过 8192 字节 - invalid_headers2.insert("X-Custom-Header", long_value.parse().unwrap()); - - assert!(!ValidationUtils::validate_headers(&invalid_headers2)); - - // 测试二进制数据(无法转换为字符串) - let mut binary_headers = HeaderMap::new(); - let binary_data = vec![0x00, 0x01, 0x02, 0x03]; - binary_headers.insert("X-Binary-Data", http::HeaderValue::from_bytes(&binary_data).unwrap()); - - // 二进制数据应该通过验证(只要长度不超过限制) - assert!(ValidationUtils::validate_headers(&binary_headers)); - - // 测试过长的二进制数据 - let mut long_binary_headers = HeaderMap::new(); - let long_binary_data = vec![0x00; 8193]; // 超过 8192 字节 - long_binary_headers.insert("X-Long-Binary", http::HeaderValue::from_bytes(&long_binary_data).unwrap()); - - assert!(!ValidationUtils::validate_headers(&long_binary_headers)); - } - - /// 测试端口号验证 - #[test] - fn test_port_validation() { - // 有效端口号 - assert!(ValidationUtils::validate_port(1)); - assert!(ValidationUtils::validate_port(80)); - assert!(ValidationUtils::validate_port(443)); - assert!(ValidationUtils::validate_port(8080)); - assert!(ValidationUtils::validate_port(65535)); - - // 无效端口号 - assert!(!ValidationUtils::validate_port(0)); // 端口 0 是保留的 - assert!(!ValidationUtils::validate_port(65536)); // 超过最大值 - assert!(!ValidationUtils::validate_port(70000)); // 远超过最大值 - } - - /// 测试 CIDR 表示法验证 - #[test] - fn test_cidr_validation() { - // 有效的 CIDR 表示法 - assert!(ValidationUtils::validate_cidr("192.168.1.0/24")); - assert!(ValidationUtils::validate_cidr("10.0.0.0/8")); - assert!(ValidationUtils::validate_cidr("0.0.0.0/0")); // 默认路由 - assert!(ValidationUtils::validate_cidr("2001:db8::/32")); - assert!(ValidationUtils::validate_cidr("::/0")); // IPv6 默认路由 - assert!(ValidationUtils::validate_cidr("192.168.1.1/32")); // 单个主机 - assert!(ValidationUtils::validate_cidr("2001:db8::1/128")); // 单个 IPv6 主机 - - // 无效的 CIDR 表示法 - assert!(!ValidationUtils::validate_cidr("")); // 空字符串 - assert!(!ValidationUtils::validate_cidr("invalid")); // 无效格式 - assert!(!ValidationUtils::validate_cidr("192.168.1.0")); // 缺少前缀长度 - assert!(!ValidationUtils::validate_cidr("192.168.1.0/33")); // 前缀长度过大 - assert!(!ValidationUtils::validate_cidr("256.256.256.256/24")); // 无效 IP - assert!(!ValidationUtils::validate_cidr("192.168.1.0/24/extra")); // 多余的部分 - assert!(!ValidationUtils::validate_cidr("192.168.1.0/-1")); // 负的前缀长度 - assert!(!ValidationUtils::validate_cidr("192.168.1.0/abc")); // 非数字前缀长度 - } - - /// 测试代理链长度验证 - #[test] - fn test_proxy_chain_length_validation() { - let chain = vec![ - "203.0.113.195".parse().unwrap(), - "198.51.100.1".parse().unwrap(), - "10.0.1.100".parse().unwrap(), - ]; - - // 链长度在限制内 - assert!(ValidationUtils::validate_proxy_chain_length(&chain, 3)); - assert!(ValidationUtils::validate_proxy_chain_length(&chain, 5)); - assert!(ValidationUtils::validate_proxy_chain_length(&chain, 10)); - - // 链长度超过限制 - assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 2)); - assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 1)); - assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 0)); - - // 空链 - let empty_chain: Vec = Vec::new(); - assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 0)); - assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 1)); - assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 10)); - } - - /// 测试代理链连续性验证 - #[test] - fn test_proxy_chain_continuity_validation() { - // 连续链(无重复相邻 IP) - let continuous_chain = vec![ - "203.0.113.195".parse().unwrap(), - "198.51.100.1".parse().unwrap(), - "10.0.1.100".parse().unwrap(), - ]; - assert!(ValidationUtils::validate_proxy_chain_continuity(&continuous_chain)); - - // 不连续链(有重复相邻 IP) - let discontinuous_chain = vec![ - "203.0.113.195".parse().unwrap(), - "198.51.100.1".parse().unwrap(), - "198.51.100.1".parse().unwrap(), // 重复 - "10.0.1.100".parse().unwrap(), - ]; - assert!(!ValidationUtils::validate_proxy_chain_continuity(&discontinuous_chain)); - - // 短链(应该总是连续的) - let short_chain = vec!["203.0.113.195".parse().unwrap()]; - assert!(ValidationUtils::validate_proxy_chain_continuity(&short_chain)); - - let two_item_chain = vec!["203.0.113.195".parse().unwrap(), "198.51.100.1".parse().unwrap()]; - assert!(ValidationUtils::validate_proxy_chain_continuity(&two_item_chain)); - - // 空链(应该总是连续的) - let empty_chain: Vec = Vec::new(); - assert!(ValidationUtils::validate_proxy_chain_continuity(&empty_chain)); - - // 有多个重复的链 - let multi_duplicate_chain = vec![ - "203.0.113.195".parse().unwrap(), - "203.0.113.195".parse().unwrap(), // 重复 1 - "198.51.100.1".parse().unwrap(), - "198.51.100.1".parse().unwrap(), // 重复 2 - ]; - assert!(!ValidationUtils::validate_proxy_chain_continuity(&multi_duplicate_chain)); - } - - /// 测试安全字符串验证 - #[test] - fn test_safe_string_validation() { - // 安全字符串 - assert!(ValidationUtils::is_safe_string("example")); - assert!(ValidationUtils::is_safe_string("example123")); - assert!(ValidationUtils::is_safe_string("example-test")); - assert!(ValidationUtils::is_safe_string("example.test")); - assert!(ValidationUtils::is_safe_string("example~test")); - assert!(ValidationUtils::is_safe_string("http://example.com/path")); - assert!(ValidationUtils::is_safe_string("https://example.com/?query=param")); - assert!(ValidationUtils::is_safe_string("user@example.com")); - assert!(ValidationUtils::is_safe_string("192.168.1.1:8080")); - assert!(ValidationUtils::is_safe_string("[2001:db8::1]:8080")); - - // 不安全字符串 - assert!(!ValidationUtils::is_safe_string("")); // 空字符串 - assert!(!ValidationUtils::is_safe_string("example test")); // 包含空格 - assert!(!ValidationUtils::is_safe_string("example\ttest")); // 包含制表符 - assert!(!ValidationUtils::is_safe_string("example\ntest")); // 包含换行符 - assert!(!ValidationUtils::is_safe_string("example")); // 包含尖括号 - assert!(!ValidationUtils::is_safe_string("example\"test")); // 包含双引号 - assert!(!ValidationUtils::is_safe_string("example'test")); // 包含单引号 - assert!(!ValidationUtils::is_safe_string("example\\test")); // 包含反斜杠 - assert!(!ValidationUtils::is_safe_string("example`test")); // 包含反引号 - assert!(!ValidationUtils::is_safe_string("example|test")); // 包含竖线 - assert!(!ValidationUtils::is_safe_string("example$test")); // 包含美元符号 - assert!(!ValidationUtils::is_safe_string("example%test")); // 包含百分号 - assert!(!ValidationUtils::is_safe_string("example^test")); // 包含脱字符 - assert!(!ValidationUtils::is_safe_string("example&test")); // 包含和号 - assert!(!ValidationUtils::is_safe_string("example(test")); // 包含括号 - assert!(!ValidationUtils::is_safe_string("example)test")); // 包含括号 - assert!(!ValidationUtils::is_safe_string("example[test")); // 包含方括号 - assert!(!ValidationUtils::is_safe_string("example]test")); // 包含方括号 - assert!(!ValidationUtils::is_safe_string("example{test")); // 包含花括号 - assert!(!ValidationUtils::is_safe_string("example}test")); // 包含花括号 - } - - /// 测试速率限制参数验证 - #[test] - fn test_rate_limit_params_validation() { - // 有效的速率限制参数 - assert!(ValidationUtils::validate_rate_limit_params(1, 1)); // 最小值 - assert!(ValidationUtils::validate_rate_limit_params(100, 60)); // 典型值 - assert!(ValidationUtils::validate_rate_limit_params(10000, 86400)); // 最大值 - - // 无效的速率限制参数 - assert!(!ValidationUtils::validate_rate_limit_params(0, 60)); // 请求数为 0 - assert!(!ValidationUtils::validate_rate_limit_params(10001, 60)); // 请求数超过最大值 - assert!(!ValidationUtils::validate_rate_limit_params(100, 0)); // 周期为 0 - assert!(!ValidationUtils::validate_rate_limit_params(100, 86401)); // 周期超过最大值 - assert!(!ValidationUtils::validate_rate_limit_params(0, 0)); // 两者都为 0 - assert!(!ValidationUtils::validate_rate_limit_params(100001, 100000)); // 两者都超过最大值 - } - - /// 测试缓存参数验证 - #[test] - fn test_cache_params_validation() { - // 有效的缓存参数 - assert!(ValidationUtils::validate_cache_params(1, 1)); // 最小值 - assert!(ValidationUtils::validate_cache_params(10000, 300)); // 典型值 - assert!(ValidationUtils::validate_cache_params(1000000, 86400)); // 最大值 - - // 无效的缓存参数 - assert!(!ValidationUtils::validate_cache_params(0, 300)); // 容量为 0 - assert!(!ValidationUtils::validate_cache_params(1000001, 300)); // 容量超过最大值 - assert!(!ValidationUtils::validate_cache_params(10000, 0)); // TTL 为 0 - assert!(!ValidationUtils::validate_cache_params(10000, 86401)); // TTL 超过最大值 - assert!(!ValidationUtils::validate_cache_params(0, 0)); // 两者都为 0 - assert!(!ValidationUtils::validate_cache_params(2000000, 100000)); // 两者都超过最大值 - } - - /// 测试敏感数据脱敏 - #[test] - fn test_sensitive_data_masking() { - let sensitive_patterns = vec!["password", "token", "secret", "authorization", "api_key"]; - - // 测试各种敏感字段的脱敏 - let test_cases = vec![ - ( - r#"{"username":"john","password":"secret123"}"#, - r#"{"username":"john","password:[REDACTED]"}"#, - ), - (r#"token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9&user=john"#, r#"token:[REDACTED]&user=john"#), - ( - r#"Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"#, - r#"Authorization:[REDACTED]"#, - ), - (r#"api_key=sk_test_1234567890abcdef"#, r#"api_key:[REDACTED]"#), - (r#"secret_key=abc123&public_key=xyz789"#, r#"secret_key:[REDACTED]&public_key=xyz789"#), - ( - r#"password=123&password_confirmation=123"#, - r#"password:[REDACTED]&password_confirmation:[REDACTED]"#, - ), - ]; - - for (input, expected) in test_cases { - let result = ValidationUtils::mask_sensitive_data(input, &sensitive_patterns); - assert_eq!(result, expected, "Failed to mask: {}", input); - } - - // 测试不包含敏感数据的情况 - let safe_data = r#"{"name":"John","age":30,"city":"New York"}"#; - let result = ValidationUtils::mask_sensitive_data(safe_data, &sensitive_patterns); - assert_eq!(result, safe_data); - - // 测试空模式列表 - let sensitive_data = r#"password=secret123"#; - let result = ValidationUtils::mask_sensitive_data(sensitive_data, &Vec::new()); - assert_eq!(result, sensitive_data); - - // 测试空输入 - let result = ValidationUtils::mask_sensitive_data("", &sensitive_patterns); - assert_eq!(result, ""); - } - - /// 测试组合验证场景 - #[test] - fn test_combined_validation_scenarios() { - // 场景 1:完整的代理请求验证 - let proxy_chain = vec![ - "203.0.113.195".parse().unwrap(), - "198.51.100.1".parse().unwrap(), - "10.0.1.100".parse().unwrap(), - ]; - - assert!(ValidationUtils::validate_proxy_chain_length(&proxy_chain, 10)); - assert!(ValidationUtils::validate_proxy_chain_continuity(&proxy_chain)); - - // 场景 2:包含无效数据的头部验证 - let mut headers = HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse().unwrap()); - headers.insert("X-Forwarded-For", "203.0.113.195, invalid, 10.0.1.100".parse().unwrap()); - - // 头部映射本身是有效的(即使包含无效的 X-Forwarded-For) - assert!(ValidationUtils::validate_headers(&headers)); - - // 但 X-Forwarded-For 内容无效 - let xff_value = headers.get("X-Forwarded-For").unwrap().to_str().unwrap(); - assert!(!ValidationUtils::validate_x_forwarded_for(xff_value)); - - // 场景 3:配置参数验证组合 - let cache_capacity = 10000; - let cache_ttl = 300; - let rate_limit_requests = 100; - let rate_limit_period = 60; - - assert!(ValidationUtils::validate_cache_params(cache_capacity, cache_ttl)); - assert!(ValidationUtils::validate_rate_limit_params(rate_limit_requests, rate_limit_period)); - - // 场景 4:IP 和 CIDR 验证组合 - let ip: IpAddr = "10.0.1.1".parse().unwrap(); - let cidr = "10.0.0.0/8"; - - assert!(IpUtils::is_private_ip(&ip)); - assert!(ValidationUtils::validate_cidr(cidr)); - assert!(ValidationUtils::validate_ip_in_range(&ip, &[cidr.to_string()])); - } - - /// 测试边缘情况和边界值 - #[test] - fn test_edge_cases_and_boundaries() { - // 测试头部值的边界长度 - let max_length_value = "a".repeat(8192); - let over_length_value = "a".repeat(8193); - - assert!(ValidationUtils::validate_header_value(&max_length_value)); - assert!(!ValidationUtils::validate_header_value(&over_length_value)); - - // 测试端口边界值 - assert!(!ValidationUtils::validate_port(0)); // 最小无效值 - assert!(ValidationUtils::validate_port(1)); // 最小有效值 - assert!(ValidationUtils::validate_port(65535)); // 最大有效值 - assert!(!ValidationUtils::validate_port(65536)); // 超过最大值 - - // 测试 CIDR 前缀长度边界值 - assert!(ValidationUtils::validate_cidr("192.168.1.0/0")); // 最小有效前缀 - assert!(ValidationUtils::validate_cidr("192.168.1.0/32")); // 最大有效前缀 - assert!(!ValidationUtils::validate_cidr("192.168.1.0/33")); // 超过最大值 - - // IPv6 CIDR 前缀长度 - assert!(ValidationUtils::validate_cidr("2001:db8::/0")); // 最小有效前缀 - assert!(ValidationUtils::validate_cidr("2001:db8::/128")); // 最大有效前缀 - - // 测试代理链边界情况 - let empty_chain: Vec = Vec::new(); - assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 0)); - assert!(ValidationUtils::validate_proxy_chain_continuity(&empty_chain)); - - // 测试单个 IP 的链 - let single_ip_chain = vec!["192.168.1.1".parse().unwrap()]; - assert!(ValidationUtils::validate_proxy_chain_length(&single_ip_chain, 1)); - assert!(ValidationUtils::validate_proxy_chain_continuity(&single_ip_chain)); - - // 测试速率限制边界值 - assert!(!ValidationUtils::validate_rate_limit_params(0, 60)); - assert!(ValidationUtils::validate_rate_limit_params(1, 1)); - assert!(ValidationUtils::validate_rate_limit_params(10000, 86400)); - assert!(!ValidationUtils::validate_rate_limit_params(10001, 86400)); - assert!(!ValidationUtils::validate_rate_limit_params(10000, 86401)); - - // 测试缓存参数边界值 - assert!(!ValidationUtils::validate_cache_params(0, 300)); - assert!(ValidationUtils::validate_cache_params(1, 1)); - assert!(ValidationUtils::validate_cache_params(1000000, 86400)); - assert!(!ValidationUtils::validate_cache_params(1000001, 86400)); - assert!(!ValidationUtils::validate_cache_params(1000000, 86401)); - } - - /// 测试性能敏感场景 - #[test] - fn test_performance_sensitive_scenarios() { - // 测试长代理链的处理 - let mut long_chain = Vec::new(); - for i in 0..100 { - let ip = format!("10.0.{}.1", i % 256).parse().unwrap(); - long_chain.push(ip); - } - - // 应该能快速处理长链 - assert!(ValidationUtils::validate_proxy_chain_length(&long_chain, 100)); - assert!(ValidationUtils::validate_proxy_chain_continuity(&long_chain)); - - // 测试大量 CIDR 范围的验证 - let mut cidr_ranges = Vec::new(); - for i in 0..1000 { - let cidr = format!("10.{}.0.0/16", i % 256); - cidr_ranges.push(cidr); - } - - let test_ip: IpAddr = "10.128.1.1".parse().unwrap(); - // 应该能快速在大范围列表中查找 - let start = std::time::Instant::now(); - let result = ValidationUtils::validate_ip_in_range(&test_ip, &cidr_ranges); - let duration = start.elapsed(); - - assert!(result); - // 验证时间应该在合理范围内(比如小于 10 毫秒) - assert!(duration < std::time::Duration::from_millis(10)); - - // 测试头部值验证的性能 - let large_header_value = "x".repeat(10000); // 超过 8192,应该快速拒绝 - let start = std::time::Instant::now(); - let result = ValidationUtils::validate_header_value(&large_header_value); - let duration = start.elapsed(); - - assert!(!result); // 应该拒绝 - assert!(duration < std::time::Duration::from_millis(1)); // 应该非常快 - } - - /// 测试实际代理场景模拟 - #[test] - fn test_real_world_proxy_scenarios() { - // 场景 1:典型的反向代理配置 - let typical_xff = "203.0.113.195, 198.51.100.1, 10.0.1.100"; - assert!(ValidationUtils::validate_x_forwarded_for(typical_xff)); - - let typical_proxy_chain: Vec = typical_xff.split(',').map(|s| s.trim().parse().unwrap()).collect(); - - assert_eq!(typical_proxy_chain.len(), 3); - assert!(ValidationUtils::validate_proxy_chain_length(&typical_proxy_chain, 10)); - assert!(ValidationUtils::validate_proxy_chain_continuity(&typical_proxy_chain)); - - // 场景 2:负载均衡器场景 - let lb_scenario = "2001:db8::1, 203.0.113.195, 198.51.100.1"; - assert!(ValidationUtils::validate_x_forwarded_for(lb_scenario)); - - // 场景 3:可能被攻击的头部 - let attack_headers = vec![ - ("X-Forwarded-For", "127.0.0.1, 8.8.8.8, 192.168.1.1"), - ("X-Real-IP", "8.8.8.8"), - ("X-Forwarded-Host", "evil.com"), - ]; - - let mut headers = HeaderMap::new(); - for (name, value) in attack_headers { - headers.insert(name, value.parse().unwrap()); - } - - // 头部格式本身应该是有效的 - assert!(ValidationUtils::validate_headers(&headers)); - - // 但内容可能需要进一步验证 - let xff_value = headers.get("X-Forwarded-For").unwrap().to_str().unwrap(); - assert!(ValidationUtils::validate_x_forwarded_for(xff_value)); - - // 场景 4:RFC 7239 格式 - let rfc7239_header = "for=192.0.2.60;proto=https;by=203.0.113.43"; - assert!(ValidationUtils::validate_forwarded_header(rfc7239_header)); - } -} diff --git a/crates/trusted-proxies/test/unit/validator_tests.rs b/crates/trusted-proxies/test/unit/validator_tests.rs deleted file mode 100644 index 2094d727..00000000 --- a/crates/trusted-proxies/test/unit/validator_tests.rs +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright 2024 RustFS Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Proxy validator unit tests - -#[cfg(test)] -mod tests { - use std::net::{IpAddr, SocketAddr}; - use std::str::FromStr; - - use axum::http::HeaderMap; - - use crate::config::{TrustedProxy, TrustedProxyConfig, ValidationMode}; - use crate::proxy::chain::ProxyChainAnalyzer; - use crate::proxy::validator::{ClientInfo, ProxyValidator}; - - fn create_test_config() -> TrustedProxyConfig { - let proxies = vec![ - TrustedProxy::Single("192.168.1.100".parse().unwrap()), - TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), - TrustedProxy::Cidr("172.16.0.0/12".parse().unwrap()), - ]; - - TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 5, true, vec![]) - } - - #[test] - fn test_client_info_direct() { - let addr = SocketAddr::new(IpAddr::from([192, 168, 1, 1]), 8080); - let client_info = ClientInfo::direct(addr); - - assert_eq!(client_info.real_ip, IpAddr::from([192, 168, 1, 1])); - assert!(client_info.forwarded_host.is_none()); - assert!(client_info.forwarded_proto.is_none()); - assert!(!client_info.is_from_trusted_proxy); - assert!(client_info.proxy_ip.is_none()); - assert_eq!(client_info.proxy_hops, 0); - assert_eq!(client_info.validation_mode, ValidationMode::Lenient); - assert!(client_info.warnings.is_empty()); - } - - #[test] - fn test_parse_x_forwarded_for() { - use crate::proxy::validator::ProxyValidator; - - // 测试有效的 X-Forwarded-For 头部 - let header_value = "203.0.113.195, 198.51.100.1, 10.0.1.100"; - let result = ProxyValidator::parse_x_forwarded_for(header_value); - - assert_eq!(result.len(), 3); - assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap()); - assert_eq!(result[1], IpAddr::from_str("198.51.100.1").unwrap()); - assert_eq!(result[2], IpAddr::from_str("10.0.1.100").unwrap()); - - // 测试带端口的 IP - let header_value_with_ports = "203.0.113.195:8080, 198.51.100.1:443"; - let result = ProxyValidator::parse_x_forwarded_for(header_value_with_ports); - - assert_eq!(result.len(), 2); - assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap()); - assert_eq!(result[1], IpAddr::from_str("198.51.100.1").unwrap()); - - // 测试空值 - let empty_result = ProxyValidator::parse_x_forwarded_for(""); - assert!(empty_result.is_empty()); - - // 测试无效 IP - let invalid_result = ProxyValidator::parse_x_forwarded_for("invalid, 203.0.113.195"); - assert_eq!(invalid_result.len(), 1); // 无效项被跳过 - } - - #[test] - fn test_proxy_chain_analyzer_lenient() { - let config = create_test_config(); - let analyzer = ProxyChainAnalyzer::new(config.clone()); - - // 测试链:客户端 -> 可信代理 1 -> 可信代理 2 - let chain = vec![ - IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 - IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1 - IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 - ]; - - let current_proxy = IpAddr::from_str("192.168.1.100").unwrap(); - let mut headers = HeaderMap::new(); - - let result = analyzer.analyze_chain(&chain, current_proxy, &headers); - assert!(result.is_ok()); - - let analysis = result.unwrap(); - assert_eq!(analysis.client_ip, IpAddr::from_str("203.0.113.195").unwrap()); - assert_eq!(analysis.hops, 3); - assert!(analysis.is_continuous); - assert_eq!(analysis.validation_mode, ValidationMode::HopByHop); - } - - #[test] - fn test_proxy_chain_analyzer_strict() { - let mut config = create_test_config(); - config.validation_mode = ValidationMode::Strict; - - let analyzer = ProxyChainAnalyzer::new(config); - - // 测试链:客户端 -> 可信代理 1 -> 可信代理 2 (全部可信) - let chain = vec![ - IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 - IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1 - IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 - ]; - - let current_proxy = IpAddr::from_str("192.168.1.100").unwrap(); - let mut headers = HeaderMap::new(); - - let result = analyzer.analyze_chain(&chain, current_proxy, &headers); - assert!(result.is_ok()); - - // 测试链包含不可信代理 - let chain_with_untrusted = vec![ - IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 - IpAddr::from_str("8.8.8.8").unwrap(), // 不可信代理 - IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 - ]; - - let result = analyzer.analyze_chain(&chain_with_untrusted, current_proxy, &headers); - assert!(result.is_err()); - } - - #[test] - fn test_proxy_chain_analyzer_hop_by_hop() { - let config = create_test_config(); - let analyzer = ProxyChainAnalyzer::new(config); - - // 测试链:客户端 -> 不可信代理 -> 可信代理 1 -> 可信代理 2 - let chain = vec![ - IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 - IpAddr::from_str("8.8.8.8").unwrap(), // 不可信代理 - IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1 - IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 - ]; - - let current_proxy = IpAddr::from_str("192.168.1.100").unwrap(); - let mut headers = HeaderMap::new(); - - let result = analyzer.analyze_chain(&chain, current_proxy, &headers); - assert!(result.is_ok()); - - let analysis = result.unwrap(); - // 应该找到客户端 IP (203.0.113.195) - assert_eq!(analysis.client_ip, IpAddr::from_str("203.0.113.195").unwrap()); - // 应该验证 2 跳 (10.0.1.100 和 192.168.1.100) - assert_eq!(analysis.hops, 2); - } - - #[test] - fn test_chain_continuity_check() { - let config = create_test_config(); - let analyzer = ProxyChainAnalyzer::new(config); - - // 测试连续链 - let full_chain = vec![ - IpAddr::from_str("203.0.113.195").unwrap(), - IpAddr::from_str("10.0.1.100").unwrap(), - IpAddr::from_str("192.168.1.100").unwrap(), - ]; - - let trusted_chain = vec![ - IpAddr::from_str("10.0.1.100").unwrap(), - IpAddr::from_str("192.168.1.100").unwrap(), - ]; - - assert!(analyzer.check_chain_continuity(&full_chain, &trusted_chain)); - - // 测试不连续链 - let bad_trusted_chain = vec![IpAddr::from_str("192.168.1.100").unwrap()]; - - assert!(!analyzer.check_chain_continuity(&full_chain, &bad_trusted_chain)); - } - - #[test] - fn test_validate_ip_addresses() { - let config = create_test_config(); - let analyzer = ProxyChainAnalyzer::new(config); - - // 测试有效 IP - let valid_chain = vec![ - IpAddr::from_str("203.0.113.195").unwrap(), - IpAddr::from_str("10.0.1.100").unwrap(), - ]; - - let result = analyzer.validate_ip_addresses(&valid_chain); - assert!(result.is_ok()); - - // 测试未指定地址 - let invalid_chain = vec![IpAddr::from_str("0.0.0.0").unwrap()]; - - let result = analyzer.validate_ip_addresses(&invalid_chain); - assert!(result.is_err()); - - // 测试多播地址 - let multicast_chain = vec![IpAddr::from_str("224.0.0.1").unwrap()]; - - let result = analyzer.validate_ip_addresses(&multicast_chain); - assert!(result.is_err()); - } - - #[test] - fn test_proxy_validator_creation() { - let config = create_test_config(); - let validator = ProxyValidator::new(config, None); - - // 验证器应该成功创建 - assert!(true); // 如果没有 panic,测试通过 - } -} diff --git a/crates/trusted-proxies/tests/integration/api_tests.rs b/crates/trusted-proxies/tests/integration/api_tests.rs new file mode 100644 index 00000000..c49008ad --- /dev/null +++ b/crates/trusted-proxies/tests/integration/api_tests.rs @@ -0,0 +1,49 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use axum::body::Body; +use axum::{routing::get, Router}; +use serde_json::{json}; +use tower::ServiceExt; +use rustfs_trusted_proxies::config::{AppConfig, TrustedProxy, TrustedProxyConfig, ValidationMode}; +use rustfs_trusted_proxies::state::AppState; + +fn create_test_app_state() -> AppState { + let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; + let proxy_config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]); + let config = AppConfig::new( + proxy_config, + rustfs_trusted_proxies::config::CacheConfig::default(), + rustfs_trusted_proxies::config::MonitoringConfig::default(), + rustfs_trusted_proxies::config::CloudConfig::default(), + "127.0.0.1:3000".parse().unwrap(), + ); + AppState { + config: Arc::new(config), + metrics: None, + } +} + +#[tokio::test] +async fn test_health_check_endpoint() { + let state = create_test_app_state(); + let app = Router::new() + .route("/health", get(|| async { axum::response::Json(json!({"status": "healthy"})) })) + .with_state(state); + + let request = axum::http::Request::builder().uri("/health").body(Body::empty()).unwrap(); + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); +} diff --git a/crates/trusted-proxies/tests/integration/cloud_tests.rs b/crates/trusted-proxies/tests/integration/cloud_tests.rs new file mode 100644 index 00000000..d93507c4 --- /dev/null +++ b/crates/trusted-proxies/tests/integration/cloud_tests.rs @@ -0,0 +1,30 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::time::Duration; +use rustfs_trusted_proxies::cloud::detector::CloudDetector; +use rustfs_trusted_proxies::cloud::metadata::AwsMetadataFetcher; + +#[tokio::test] +async fn test_cloud_detector_disabled() { + let detector = CloudDetector::new(false, Duration::from_secs(1), None); + let provider = detector.detect_provider(); + assert!(provider.is_none()); +} + +#[tokio::test] +async fn test_aws_metadata_fetcher() { + let fetcher = AwsMetadataFetcher::new(); + assert_eq!(fetcher.provider_name(), "aws"); +} diff --git a/crates/trusted-proxies/test/integration/mod.rs b/crates/trusted-proxies/tests/integration/mod.rs similarity index 80% rename from crates/trusted-proxies/test/integration/mod.rs rename to crates/trusted-proxies/tests/integration/mod.rs index fa0bece3..2df0222c 100644 --- a/crates/trusted-proxies/test/integration/mod.rs +++ b/crates/trusted-proxies/tests/integration/mod.rs @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Integration tests for the trusted proxy system +//! Integration tests for the trusted proxy system. +#[cfg(test)] mod api_tests; +#[cfg(test)] mod cloud_tests; +#[cfg(test)] mod proxy_tests; - -// 重新导出测试模块 -pub use api_tests::*; -pub use cloud_tests::*; -pub use proxy_tests::*; diff --git a/crates/trusted-proxies/tests/integration/proxy_tests.rs b/crates/trusted-proxies/tests/integration/proxy_tests.rs new file mode 100644 index 00000000..15f8acf1 --- /dev/null +++ b/crates/trusted-proxies/tests/integration/proxy_tests.rs @@ -0,0 +1,39 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::body::Body; +use axum::{routing::get, Router}; +use tower::ServiceExt; +use rustfs_trusted_proxies::config::{TrustedProxy, TrustedProxyConfig, ValidationMode}; +use rustfs_trusted_proxies::middleware::TrustedProxyLayer; + +#[tokio::test] +async fn test_proxy_validation_flow() { + let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; + let config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]); + let proxy_layer = TrustedProxyLayer::enabled(config, None); + + let app = Router::new() + .route("/test", get(|| async { "OK" })) + .layer(proxy_layer); + + let request = axum::http::Request::builder() + .uri("/test") + .header("X-Forwarded-For", "203.0.113.195") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); +} diff --git a/crates/trusted-proxies/tests/unit/config_tests.rs b/crates/trusted-proxies/tests/unit/config_tests.rs new file mode 100644 index 00000000..7ac71574 --- /dev/null +++ b/crates/trusted-proxies/tests/unit/config_tests.rs @@ -0,0 +1,117 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::IpAddr; +use rustfs_trusted_proxies::config::env::{DEFAULT_TRUSTED_PROXIES, ENV_TRUSTED_PROXIES}; +use rustfs_trusted_proxies::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode}; + +#[test] +fn test_config_loader_default() { + std::env::remove_var(ENV_TRUSTED_PROXIES); + let config = ConfigLoader::from_env_or_default(); + assert_eq!(config.server_addr.port(), 3000); + assert!(!config.proxy.proxies.is_empty()); + assert_eq!(config.proxy.validation_mode, ValidationMode::HopByHop); + assert!(config.proxy.enable_rfc7239); + assert_eq!(config.proxy.max_hops, 10); +} + +#[test] +fn test_config_loader_env_vars() { + std::env::set_var(ENV_TRUSTED_PROXIES, "192.168.1.0/24,10.0.0.0/8"); + std::env::set_var("TRUSTED_PROXY_VALIDATION_MODE", "strict"); + std::env::set_var("TRUSTED_PROXY_MAX_HOPS", "5"); + std::env::set_var("SERVER_PORT", "8080"); + + let config = ConfigLoader::from_env(); + + if let Ok(config) = config { + assert_eq!(config.server_addr.port(), 8080); + assert_eq!(config.proxy.validation_mode, ValidationMode::Strict); + assert_eq!(config.proxy.max_hops, 5); + + std::env::remove_var(ENV_TRUSTED_PROXIES); + std::env::remove_var("TRUSTED_PROXY_VALIDATION_MODE"); + std::env::remove_var("TRUSTED_PROXY_MAX_HOPS"); + std::env::remove_var("SERVER_PORT"); + } else { + panic!("Failed to load configuration from environment variables"); + } +} + +#[test] +fn test_trusted_proxy_config() { + let proxies = vec![ + TrustedProxy::Single("192.168.1.1".parse().unwrap()), + TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), + ]; + + let config = TrustedProxyConfig::new(proxies.clone(), ValidationMode::Strict, true, 10, true, vec![]); + + assert_eq!(config.proxies.len(), 2); + assert_eq!(config.validation_mode, ValidationMode::Strict); + assert!(config.enable_rfc7239); + assert_eq!(config.max_hops, 10); + assert!(config.enable_chain_continuity_check); + + let test_ip: IpAddr = "192.168.1.1".parse().unwrap(); + let test_socket_addr = std::net::SocketAddr::new(test_ip, 8080); + assert!(config.is_trusted(&test_socket_addr)); + + let test_ip2: IpAddr = "10.0.1.1".parse().unwrap(); + let test_socket_addr2 = std::net::SocketAddr::new(test_ip2, 8080); + assert!(config.is_trusted(&test_socket_addr2)); +} + +#[test] +fn test_trusted_proxy_contains() { + let single_proxy = TrustedProxy::Single("192.168.1.1".parse().unwrap()); + let test_ip: IpAddr = "192.168.1.1".parse().unwrap(); + let test_ip2: IpAddr = "192.168.1.2".parse().unwrap(); + + assert!(single_proxy.contains(&test_ip)); + assert!(!single_proxy.contains(&test_ip2)); + + let cidr_proxy = TrustedProxy::Cidr("192.168.1.0/24".parse().unwrap()); + assert!(cidr_proxy.contains(&test_ip)); + assert!(cidr_proxy.contains(&test_ip2)); + + let test_ip3: IpAddr = "192.168.2.1".parse().unwrap(); + assert!(!cidr_proxy.contains(&test_ip3)); +} + +#[test] +fn test_private_network_check() { + let config = TrustedProxyConfig::new( + Vec::new(), + ValidationMode::Lenient, + true, + 10, + true, + vec!["10.0.0.0/8".parse().unwrap(), "192.168.0.0/16".parse().unwrap()], + ); + + let private_ip: IpAddr = "10.0.1.1".parse().unwrap(); + let private_ip2: IpAddr = "192.168.1.1".parse().unwrap(); + let public_ip: IpAddr = "8.8.8.8".parse().unwrap(); + + assert!(config.is_private_network(&private_ip)); + assert!(config.is_private_network(&private_ip2)); + assert!(!config.is_private_network(&public_ip)); +} + +#[test] +fn test_default_values() { + assert_eq!(DEFAULT_TRUSTED_PROXIES, "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"); +} diff --git a/crates/trusted-proxies/tests/unit/ip_tests.rs b/crates/trusted-proxies/tests/unit/ip_tests.rs new file mode 100644 index 00000000..455e98f6 --- /dev/null +++ b/crates/trusted-proxies/tests/unit/ip_tests.rs @@ -0,0 +1,194 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::IpAddr; +use std::str::FromStr; +use rustfs_trusted_proxies::utils::IpUtils; + +#[test] +fn test_is_valid_ip_address() { + let valid_ip: IpAddr = "192.168.1.1".parse().unwrap(); + assert!(IpUtils::is_valid_ip_address(&valid_ip)); + + let unspecified_ip: IpAddr = "0.0.0.0".parse().unwrap(); + assert!(!IpUtils::is_valid_ip_address(&unspecified_ip)); + + let multicast_ip: IpAddr = "224.0.0.1".parse().unwrap(); + assert!(!IpUtils::is_valid_ip_address(&multicast_ip)); + + let valid_ipv6: IpAddr = "2001:db8::1".parse().unwrap(); + assert!(IpUtils::is_valid_ip_address(&valid_ipv6)); + + let unspecified_ipv6: IpAddr = "::".parse().unwrap(); + assert!(!IpUtils::is_valid_ip_address(&unspecified_ipv6)); +} + +#[test] +fn test_is_reserved_ip() { + let private_ip: IpAddr = "10.0.0.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&private_ip)); + + let loopback_ip: IpAddr = "127.0.0.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&loopback_ip)); + + let link_local_ip: IpAddr = "169.254.0.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&link_local_ip)); + + let documentation_ip: IpAddr = "192.0.2.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&documentation_ip)); + + let public_ip: IpAddr = "8.8.8.8".parse().unwrap(); + assert!(!IpUtils::is_reserved_ip(&public_ip)); +} + +#[test] +fn test_is_private_ip() { + assert!(IpUtils::is_private_ip(&"10.0.0.1".parse().unwrap())); + assert!(IpUtils::is_private_ip(&"10.255.255.254".parse().unwrap())); + + assert!(IpUtils::is_private_ip(&"172.16.0.1".parse().unwrap())); + assert!(IpUtils::is_private_ip(&"172.31.255.254".parse().unwrap())); + assert!(!IpUtils::is_private_ip(&"172.15.0.1".parse().unwrap())); + assert!(!IpUtils::is_private_ip(&"172.32.0.1".parse().unwrap())); + + assert!(IpUtils::is_private_ip(&"192.168.0.1".parse().unwrap())); + assert!(IpUtils::is_private_ip(&"192.168.255.254".parse().unwrap())); + + assert!(!IpUtils::is_private_ip(&"8.8.8.8".parse().unwrap())); + assert!(!IpUtils::is_private_ip(&"203.0.113.1".parse().unwrap())); +} + +#[test] +fn test_is_loopback_ip() { + assert!(IpUtils::is_loopback_ip(&"127.0.0.1".parse().unwrap())); + assert!(IpUtils::is_loopback_ip(&"127.255.255.254".parse().unwrap())); + assert!(IpUtils::is_loopback_ip(&"::1".parse().unwrap())); + assert!(!IpUtils::is_loopback_ip(&"192.168.1.1".parse().unwrap())); + assert!(!IpUtils::is_loopback_ip(&"2001:db8::1".parse().unwrap())); +} + +#[test] +fn test_is_link_local_ip() { + assert!(IpUtils::is_link_local_ip(&"169.254.0.1".parse().unwrap())); + assert!(IpUtils::is_link_local_ip(&"169.254.255.254".parse().unwrap())); + assert!(IpUtils::is_link_local_ip(&"fe80::1".parse().unwrap())); + assert!(IpUtils::is_link_local_ip(&"fe80::abcd:1234:5678:9abc".parse().unwrap())); + assert!(!IpUtils::is_link_local_ip(&"192.168.1.1".parse().unwrap())); + assert!(!IpUtils::is_link_local_ip(&"2001:db8::1".parse().unwrap())); +} + +#[test] +fn test_is_documentation_ip() { + assert!(IpUtils::is_documentation_ip(&"192.0.2.1".parse().unwrap())); + assert!(IpUtils::is_documentation_ip(&"198.51.100.1".parse().unwrap())); + assert!(IpUtils::is_documentation_ip(&"203.0.113.1".parse().unwrap())); + assert!(IpUtils::is_documentation_ip(&"2001:db8::1".parse().unwrap())); + assert!(!IpUtils::is_documentation_ip(&"8.8.8.8".parse().unwrap())); + assert!(!IpUtils::is_documentation_ip(&"2001:4860::1".parse().unwrap())); +} + +#[test] +fn test_parse_ip_or_cidr() { + let result = IpUtils::parse_ip_or_cidr("192.168.1.1"); + assert!(result.is_ok()); + + let result = IpUtils::parse_ip_or_cidr("192.168.1.0/24"); + assert!(result.is_ok()); + + let result = IpUtils::parse_ip_or_cidr("2001:db8::1"); + assert!(result.is_ok()); + + let result = IpUtils::parse_ip_or_cidr("2001:db8::/32"); + assert!(result.is_ok()); + + let result = IpUtils::parse_ip_or_cidr("invalid"); + assert!(result.is_err()); +} + +#[test] +fn test_parse_ip_list() { + let result = IpUtils::parse_ip_list("192.168.1.1, 10.0.0.1, 8.8.8.8"); + assert!(result.is_ok()); + + let ips = result.unwrap(); + assert_eq!(ips.len(), 3); + assert_eq!(ips[0], IpAddr::from_str("192.168.1.1").unwrap()); + assert_eq!(ips[1], IpAddr::from_str("10.0.0.1").unwrap()); + assert_eq!(ips[2], IpAddr::from_str("8.8.8.8").unwrap()); + + let result = IpUtils::parse_ip_list("192.168.1.1,10.0.0.1"); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 2); + + let result = IpUtils::parse_ip_list(""); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + + let result = IpUtils::parse_ip_list("192.168.1.1, invalid"); + assert!(result.is_err()); +} + +#[test] +fn test_parse_network_list() { + let result = IpUtils::parse_network_list("192.168.1.0/24, 10.0.0.0/8"); + assert!(result.is_ok()); + + let networks = result.unwrap(); + assert_eq!(networks.len(), 2); + + let result = IpUtils::parse_network_list("192.168.1.1"); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 1); + + let result = IpUtils::parse_network_list("192.168.1.0/24, invalid"); + assert!(result.is_err()); +} + +#[test] +fn test_ip_in_networks() { + let networks = vec!["10.0.0.0/8".parse().unwrap(), "192.168.1.0/24".parse().unwrap()]; + + let ip_in_network: IpAddr = "10.0.1.1".parse().unwrap(); + let ip_in_network2: IpAddr = "192.168.1.100".parse().unwrap(); + let ip_not_in_network: IpAddr = "8.8.8.8".parse().unwrap(); + + assert!(IpUtils::ip_in_networks(&ip_in_network, &networks)); + assert!(IpUtils::ip_in_networks(&ip_in_network2, &networks)); + assert!(!IpUtils::ip_in_networks(&ip_not_in_network, &networks)); +} + +#[test] +fn test_get_ip_type() { + assert_eq!(IpUtils::get_ip_type(&"10.0.0.1".parse().unwrap()), "private"); + assert_eq!(IpUtils::get_ip_type(&"127.0.0.1".parse().unwrap()), "loopback"); + assert_eq!(IpUtils::get_ip_type(&"169.254.0.1".parse().unwrap()), "link_local"); + assert_eq!(IpUtils::get_ip_type(&"192.0.2.1".parse().unwrap()), "documentation"); + assert_eq!(IpUtils::get_ip_type(&"224.0.0.1".parse().unwrap()), "reserved"); + assert_eq!(IpUtils::get_ip_type(&"8.8.8.8".parse().unwrap()), "public"); +} + +#[test] +fn test_canonical_ip() { + let ipv4: IpAddr = "192.168.001.001".parse().unwrap(); + assert_eq!(IpUtils::canonical_ip(&ipv4), "192.168.1.1"); + + let ipv6_full: IpAddr = "2001:0db8:0000:0000:0000:0000:0000:0001".parse().unwrap(); + let ipv6_compressed: IpAddr = "2001:db8::1".parse().unwrap(); + + assert_eq!(IpUtils::canonical_ip(&ipv6_full), "2001:db8::1"); + assert_eq!(IpUtils::canonical_ip(&ipv6_compressed), "2001:db8::1"); + + let ipv6_multi_zero: IpAddr = "2001:0db8:0000:0000:abcd:0000:0000:1234".parse().unwrap(); + assert_eq!(IpUtils::canonical_ip(&ipv6_multi_zero), "2001:db8::abcd:0:0:1234"); +} diff --git a/crates/trusted-proxies/test/unit/mod.rs b/crates/trusted-proxies/tests/unit/mod.rs similarity index 79% rename from crates/trusted-proxies/test/unit/mod.rs rename to crates/trusted-proxies/tests/unit/mod.rs index 92a5b076..fb112a67 100644 --- a/crates/trusted-proxies/test/unit/mod.rs +++ b/crates/trusted-proxies/tests/unit/mod.rs @@ -12,15 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Unit tests for the trusted proxy system +//! Unit tests for the trusted proxy system components. +#[cfg(test)] mod config_tests; +#[cfg(test)] mod ip_tests; +#[cfg(test)] mod validation_tests; +#[cfg(test)] mod validator_tests; - -// 重新导出测试模块 -pub use config_tests::*; -pub use ip_tests::*; -pub use validation_tests::*; -pub use validator_tests::*; diff --git a/crates/trusted-proxies/tests/unit/validation_tests.rs b/crates/trusted-proxies/tests/unit/validation_tests.rs new file mode 100644 index 00000000..5014498c --- /dev/null +++ b/crates/trusted-proxies/tests/unit/validation_tests.rs @@ -0,0 +1,69 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use http::HeaderMap; +use std::net::IpAddr; +use rustfs_trusted_proxies::utils::{IpUtils, ValidationUtils}; + +#[test] +fn test_email_validation() { + assert!(ValidationUtils::is_valid_email("user@example.com")); + assert!(!ValidationUtils::is_valid_email("invalid-email")); +} + +#[test] +fn test_url_validation() { + assert!(ValidationUtils::is_valid_url("https://example.com")); + assert!(!ValidationUtils::is_valid_url("invalid")); +} + +#[test] +fn test_x_forwarded_for_validation() { + assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195")); + assert!(!ValidationUtils::validate_x_forwarded_for("invalid")); +} + +#[test] +fn test_forwarded_header_validation() { + assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60")); + assert!(!ValidationUtils::validate_forwarded_header("invalid")); +} + +#[test] +fn test_ip_in_range_validation() { + let cidr_ranges = vec![ + "10.0.0.0/8".to_string(), + "192.168.0.0/16".to_string(), + ]; + let ip: IpAddr = "10.0.1.1".parse().unwrap(); + assert!(ValidationUtils::validate_ip_in_range(&ip, &cidr_ranges)); +} + +#[test] +fn test_header_value_validation() { + assert!(ValidationUtils::validate_header_value("text/plain")); + assert!(!ValidationUtils::validate_header_value(&"a".repeat(8193))); +} + +#[test] +fn test_port_validation() { + assert!(ValidationUtils::validate_port(80)); + assert!(!ValidationUtils::validate_port(0)); +} + +#[test] +fn test_cidr_validation() { + assert!(ValidationUtils::validate_cidr("192.168.1.0/24")); + assert!(!ValidationUtils::validate_cidr("invalid")); +} diff --git a/crates/trusted-proxies/tests/unit/validator_tests.rs b/crates/trusted-proxies/tests/unit/validator_tests.rs new file mode 100644 index 00000000..45402d2b --- /dev/null +++ b/crates/trusted-proxies/tests/unit/validator_tests.rs @@ -0,0 +1,56 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use axum::http::HeaderMap; +use rustfs_trusted_proxies::config::{TrustedProxy, TrustedProxyConfig, ValidationMode}; +use rustfs_trusted_proxies::proxy::chain::ProxyChainAnalyzer; +use rustfs_trusted_proxies::proxy::validator::{ClientInfo, ProxyValidator}; + +fn create_test_config() -> TrustedProxyConfig { + let proxies = vec![ + TrustedProxy::Single("192.168.1.100".parse().unwrap()), + TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), + ]; + TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 5, true, vec![]) +} + +#[test] +fn test_client_info_direct() { + let addr = SocketAddr::new(IpAddr::from([192, 168, 1, 1]), 8080); + let client_info = ClientInfo::direct(addr); + assert_eq!(client_info.real_ip, IpAddr::from([192, 168, 1, 1])); +} + +#[test] +fn test_parse_x_forwarded_for() { + let header_value = "203.0.113.195, 198.51.100.1"; + let result = ProxyValidator::parse_x_forwarded_for(header_value); + assert_eq!(result.len(), 2); +} + +#[test] +fn test_proxy_chain_analyzer_hop_by_hop() { + let config = create_test_config(); + let analyzer = ProxyChainAnalyzer::new(config); + let chain = vec![ + IpAddr::from_str("203.0.113.195").unwrap(), + IpAddr::from_str("10.0.1.100").unwrap(), + ]; + let current_proxy = IpAddr::from_str("192.168.1.100").unwrap(); + let headers = HeaderMap::new(); + let result = analyzer.analyze_chain(&chain, current_proxy, &headers); + assert!(result.is_ok()); +} diff --git a/docs/special-characters-README_ZH.md b/docs/special-characters-README_ZH.md index a1d87e06..b21a5e73 100644 --- a/docs/special-characters-README_ZH.md +++ b/docs/special-characters-README_ZH.md @@ -1,6 +1,6 @@ # 对象路径中的特殊字符 - 完整文档 -本目录包含关于在 RustFS 中处理 S3 对象路径中特殊字符(空格、加号、百分号等)的完整文档。 +本目录包含关于在 RustFS 中处理 S3 对象路径中特殊字符 (空格、加号、百分号等) 的完整文档。 ## 快速链接 @@ -12,18 +12,18 @@ ### 问题现象 -用户报告了两个问题: -1. **问题 A**: UI 可以导航到包含特殊字符的文件夹,但无法列出其中的内容 +用户报告了两个问题: +1. **问题 A**: UI 可以导航到包含特殊字符的文件夹,但无法列出其中的内容 2. **问题 B**: 上传路径中包含 `+` 号的文件时出现 400 错误 ### 根本原因 -经过深入调查,包括检查 s3s 库的源代码,我们发现: +经过深入调查,包括检查 s3s 库的源代码,我们发现: **后端 (RustFS) 工作正常** ✅ - s3s 库正确地对 HTTP 请求中的对象键进行 URL 解码 - RustFS 正确存储和检索包含特殊字符的对象 -- 命令行工具(mc, aws-cli)完美工作 → 证明后端正确处理特殊字符 +- 命令行工具 (mc, aws-cli) 完美工作 → 证明后端正确处理特殊字符 **问题出在 UI/客户端层** ❌ - 某些客户端未正确进行 URL 编码 @@ -33,8 +33,8 @@ ### 解决方案 1. **用户**: 使用正规的 S3 SDK/客户端(它们会自动处理编码) -2. **开发者**: 后端无需修复,但添加了防御性验证和日志 -3. **UI**: UI 需要正确对所有请求进行 URL 编码(如适用) +2. **开发者**: 后端无需修复,但添加了防御性验证和日志 +3. **UI**: UI 需要正确对所有请求进行 URL 编码 (如适用) ## URL 编码快速参考 @@ -44,11 +44,11 @@ | 加号 | `+` | `%2B` | `%2B` | | 百分号 | `%` | `%25` | `%25` | -**重要**: 在 URL **路径**中,`+` = 字面加号(不是空格)。只有 `%20` = 空格! +**重要**: 在 URL **路径**中,`+` = 字面加号 (不是空格)。只有 `%20` = 空格! ## 快速示例 -### ✅ 正确使用(使用 mc) +### ✅ 正确使用 (使用 mc) ```bash # 上传 @@ -60,7 +60,7 @@ mc ls "myrustfs/bucket/路径 包含 空格/" # 结果: ✅ 成功 - mc 正确编码了请求 ``` -### ❌ 可能失败(原始 HTTP 未编码) +### ❌ 可能失败 (原始 HTTP 未编码) ```bash # 错误: 未编码 @@ -82,7 +82,7 @@ curl "http://localhost:9000/bucket/%E8%B7%AF%E5%BE%84%20%E5%8C%85%E5%90%AB%20%E7 ### ✅ 已完成 -1. **后端验证**: 添加了控制字符验证(拒绝空字节、换行符) +1. **后端验证**: 添加了控制字符验证 (拒绝空字节、换行符) 2. **调试日志**: 为包含特殊字符的键添加了日志记录 3. **测试**: 创建了综合 e2e 测试套件 4. **文档**: @@ -103,7 +103,7 @@ curl "http://localhost:9000/bucket/%E8%B7%AF%E5%BE%84%20%E5%8C%85%E5%90%AB%20%E7 3. **用户沟通**: - 更新用户文档 - 在 FAQ 中添加故障排除 - - 传达已知的 UI 限制(如有) + - 传达已知的 UI 限制 (如有) ## 测试 @@ -134,15 +134,15 @@ aws --endpoint-url=http://localhost:9000 s3 ls "s3://bucket/测试 包含 空格 ## 支持 -如果遇到特殊字符问题: +如果遇到特殊字符问题: 1. **首先**: 查看[客户端指南](./client-special-characters-guide.md) 2. **尝试**: 使用 mc 或 AWS CLI 隔离问题 -3. **启用**: 调试日志: `RUST_LOG=rustfs=debug` -4. **报告**: 创建问题,包含: +3. **启用**: 调试日志:`RUST_LOG=rustfs=debug` +4. **报告**: 创建问题,包含: - 使用的客户端/SDK - 导致问题的确切对象名称 - - mc 是否工作(以隔离后端与客户端) + - mc 是否工作 (以隔离后端与客户端) - 调试日志 ## 相关文档 @@ -154,26 +154,26 @@ aws --endpoint-url=http://localhost:9000 s3 ls "s3://bucket/测试 包含 空格 ## 常见问题 -**问: 可以在对象名称中使用空格吗?** -答: 可以,但请使用能自动处理编码的 S3 SDK。 +**问:可以在对象名称中使用空格吗?** +答:可以,但请使用能自动处理编码的 S3 SDK。 -**问: 为什么 `+` 不能用作空格?** -答: 在 URL 路径中,`+` 表示字面加号。只有在查询参数中 `+` 才表示空格。在路径中使用 `%20` 表示空格。 +**问:为什么 `+` 不能用作空格?** +答:在 URL 路径中,`+` 表示字面加号。只有在查询参数中 `+` 才表示空格。在路径中使用 `%20` 表示空格。 -**问: RustFS 支持对象名称中的 Unicode 吗?** -答: 支持,对象名称是 UTF-8 字符串。它们支持任何有效的 UTF-8 字符。 +**问:RustFS 支持对象名称中的 Unicode 吗?** +答:支持,对象名称是 UTF-8 字符串。它们支持任何有效的 UTF-8 字符。 -**问: 哪些字符是禁止的?** -答: 控制字符(空字节、换行符、回车符)被拒绝。所有可打印字符都是允许的。 +**问:哪些字符是禁止的?** +答:控制字符 (空字节、换行符、回车符) 被拒绝。所有可打印字符都是允许的。 -**问: 如何修复"UI 无法列出文件夹"的问题?** -答: 使用 CLI(mc 或 aws-cli)代替。这是 UI 错误,不是后端问题。 +**问:如何修复"UI 无法列出文件夹"的问题?** +答:使用 CLI(mc 或 aws-cli) 代替。这是 UI 错误,不是后端问题。 ## 版本历史 - **v1.0** (2025-12-09): 初始文档 - 完成综合分析 - - 确定根本原因(UI/客户端问题) + - 确定根本原因 (UI/客户端问题) - 添加后端验证和日志 - 创建客户端指南 - 添加 E2E 测试