diff --git a/Cargo.lock b/Cargo.lock index 43c6b1ba..e2db2e9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3254,6 +3254,12 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "doxygen-rs" version = "0.4.2" @@ -8387,10 +8393,15 @@ dependencies = [ "anyhow", "async-trait", "axum", + "chrono", + "dotenvy", "http 1.4.0", "ipnetwork", + "lazy_static", "metrics", "moka", + "parking_lot", + "regex", "reqwest", "rustfs-utils", "serde", @@ -8398,7 +8409,10 @@ dependencies = [ "thiserror 2.0.17", "tokio", "tower", + "tower-http", "tracing", + "tracing-subscriber", + "uuid", ] [[package]] diff --git a/crates/trusted-proxies/.env.example b/crates/trusted-proxies/.env.example new file mode 100644 index 00000000..9ec74d9d --- /dev/null +++ b/crates/trusted-proxies/.env.example @@ -0,0 +1,33 @@ +# Server Configuration +SERVER_HOST=0.0.0.0 +SERVER_PORT=3000 + +# Trusted Proxy Configuration +TRUSTED_PROXY_NETWORKS=127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8 +TRUSTED_PROXY_EXTRA_NETWORKS= +TRUSTED_PROXY_VALIDATION_MODE=hop_by_hop +TRUSTED_PROXY_ENABLE_RFC7239=true +TRUSTED_PROXY_MAX_HOPS=10 +TRUSTED_PROXY_CHAIN_CONTINUITY_CHECK=true +TRUSTED_PROXY_LOG_FAILED_VALIDATIONS=true + +# Cache Configuration +TRUSTED_PROXY_CACHE_CAPACITY=10000 +TRUSTED_PROXY_CACHE_TTL_SECONDS=300 +TRUSTED_PROXY_CACHE_CLEANUP_INTERVAL=60 + +# Monitoring Configuration +TRUSTED_PROXY_METRICS_ENABLED=true +TRUSTED_PROXY_LOG_LEVEL=info +TRUSTED_PROXY_STRUCTURED_LOGGING=false +TRUSTED_PROXY_TRACING_ENABLED=true + +# Cloud Integration +TRUSTED_PROXY_CLOUD_METADATA_ENABLED=false +TRUSTED_PROXY_CLOUD_METADATA_TIMEOUT=5 +TRUSTED_PROXY_CLOUDFLARE_IPS_ENABLED=false +TRUSTED_PROXY_CLOUD_PROVIDER_FORCE= + +# Application +RUST_LOG=info +RUST_BACKTRACE=1 \ No newline at end of file diff --git a/crates/trusted-proxies/Cargo.toml b/crates/trusted-proxies/Cargo.toml index 33e2a9d4..5fe880ff 100644 --- a/crates/trusted-proxies/Cargo.toml +++ b/crates/trusted-proxies/Cargo.toml @@ -28,7 +28,9 @@ categories = ["network-programming", "security", "web-programming"] anyhow = { workspace = true } async-trait = { workspace = true } axum = { workspace = true } +chrono = { workspace = true } http = { workspace = true } +tower-http = { workspace = true } ipnetwork = { workspace = true } metrics = { workspace = true } moka = { workspace = true } @@ -40,6 +42,12 @@ thiserror = { workspace = true } tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync", "time", "test-util"] } tower = { workspace = true } tracing = { workspace = true } +tracing-subscriber = { workspace = true } +parking_lot = { workspace = true } +uuid = { workspace = true, features = ["v4"] } +regex = { workspace = true } +lazy_static = { workspace = true } +dotenvy = "0.15.7" [lints] workspace = true diff --git a/crates/trusted-proxies/src/api/handlers.rs b/crates/trusted-proxies/src/api/handlers.rs index 6238cfff..ed0e34b2 100644 --- a/crates/trusted-proxies/src/api/handlers.rs +++ b/crates/trusted-proxies/src/api/handlers.rs @@ -11,3 +11,138 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! API request handlers + +use axum::{ + extract::{Request, State}, + http::StatusCode, + response::{IntoResponse, Json}, +}; +use serde_json::{json, Value}; + +use crate::error::AppError; +use crate::middleware::ClientInfo; +use crate::AppState; + +/// 健康检查端点 +pub async fn health_check() -> impl IntoResponse { + Json(json!({ + "status": "healthy", + "timestamp": chrono::Utc::now().to_rfc3339(), + "service": "trusted-proxy", + "version": env!("CARGO_PKG_VERSION"), + })) +} + +/// 显示配置信息 +pub async fn show_config(State(state): State) -> Result, AppError> { + let config = &state.config; + + let response = json!({ + "server": { + "addr": config.server_addr.to_string(), + }, + "proxy": { + "trusted_networks_count": config.proxy.proxies.len(), + "validation_mode": format!("{:?}", config.proxy.validation_mode), + "max_hops": config.proxy.max_hops, + "enable_rfc7239": config.proxy.enable_rfc7239, + }, + "cache": { + "capacity": config.cache.capacity, + "ttl_seconds": config.cache.ttl_seconds, + }, + "monitoring": { + "metrics_enabled": config.monitoring.metrics_enabled, + "log_level": config.monitoring.log_level, + }, + "cloud": { + "metadata_enabled": config.cloud.metadata_enabled, + "cloudflare_enabled": config.cloud.cloudflare_ips_enabled, + }, + }); + + Ok(Json(response)) +} + +/// 显示客户端信息 +pub async fn client_info(State(state): State, req: Request) -> impl IntoResponse { + // 从请求扩展中获取客户端信息 + let client_info = req.extensions().get::(); + + match client_info { + Some(info) => { + let response = json!({ + "client": { + "real_ip": info.real_ip.to_string(), + "is_from_trusted_proxy": info.is_from_trusted_proxy, + "proxy_hops": info.proxy_hops, + "validation_mode": format!("{:?}", info.validation_mode), + }, + "headers": { + "forwarded_host": info.forwarded_host, + "forwarded_proto": info.forwarded_proto, + }, + "warnings": info.warnings, + "timestamp": chrono::Utc::now().to_rfc3339(), + }); + + Json(response).into_response() + } + None => { + let response = json!({ + "error": "Client information not available", + "message": "The trusted proxy middleware may not be enabled or configured correctly", + }); + + (StatusCode::INTERNAL_SERVER_ERROR, Json(response)).into_response() + } + } +} + +/// 代理测试端点(用于测试代理头部) +pub async fn proxy_test(req: Request) -> Json { + // 收集所有代理相关的头部 + let headers: Vec<(String, String)> = req + .headers() + .iter() + .filter(|(name, _)| { + let name_str = name.as_str().to_lowercase(); + name_str.contains("forwarded") || name_str.contains("x-forwarded") || name_str.contains("x-real") + }) + .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("[INVALID]").to_string())) + .collect(); + + // 获取对端地址 + let peer_addr = req + .extensions() + .get::() + .map(|addr| addr.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + + Json(json!({ + "peer_addr": peer_addr, + "method": req.method().to_string(), + "uri": req.uri().to_string(), + "proxy_headers": headers, + "timestamp": chrono::Utc::now().to_rfc3339(), + })) +} + +/// 指标端点(Prometheus 格式) +pub async fn metrics(State(state): State) -> impl IntoResponse { + if !state.config.monitoring.metrics_enabled { + return (StatusCode::NOT_FOUND, "Metrics are not enabled".to_string()).into_response(); + } + + // 在实际应用中,这里应该返回 Prometheus 格式的指标 + // 这里返回简单的 JSON 作为示例 + let metrics = json!({ + "message": "Metrics endpoint", + "note": "In a real implementation, this would return Prometheus format metrics", + "status": "metrics_enabled", + }); + + Json(metrics).into_response() +} diff --git a/crates/trusted-proxies/src/cloud/detector.rs b/crates/trusted-proxies/src/cloud/detector.rs index 6238cfff..2b513d0a 100644 --- a/crates/trusted-proxies/src/cloud/detector.rs +++ b/crates/trusted-proxies/src/cloud/detector.rs @@ -11,3 +11,251 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Cloud provider detection and metadata fetching + +use async_trait::async_trait; +use std::time::Duration; +use tracing::{debug, info, warn}; + +use crate::error::AppError; + +/// 云服务商类型 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CloudProvider { + /// Amazon Web Services + Aws, + /// Microsoft Azure + Azure, + /// Google Cloud Platform + Gcp, + /// DigitalOcean + DigitalOcean, + /// Cloudflare + Cloudflare, + /// 未知或自定义 + Unknown(String), +} + +impl CloudProvider { + /// 从环境变量检测云服务商 + pub fn detect_from_env() -> Option { + // 检查 AWS 环境变量 + if std::env::var("AWS_EXECUTION_ENV").is_ok() + || std::env::var("AWS_REGION").is_ok() + || std::env::var("EC2_INSTANCE_ID").is_ok() + { + return Some(Self::Aws); + } + + // 检查 Azure 环境变量 + if std::env::var("WEBSITE_SITE_NAME").is_ok() + || std::env::var("WEBSITE_INSTANCE_ID").is_ok() + || std::env::var("APPSETTING_WEBSITE_SITE_NAME").is_ok() + { + return Some(Self::Azure); + } + + // 检查 GCP 环境变量 + if std::env::var("GCP_PROJECT").is_ok() + || std::env::var("GOOGLE_CLOUD_PROJECT").is_ok() + || std::env::var("GAE_INSTANCE").is_ok() + { + return Some(Self::Gcp); + } + + // 检查 DigitalOcean 环境变量 + if std::env::var("DIGITALOCEAN_REGION").is_ok() { + return Some(Self::DigitalOcean); + } + + // 检查 Cloudflare 环境变量 + if std::env::var("CF_PAGES").is_ok() || std::env::var("CF_WORKERS").is_ok() { + return Some(Self::Cloudflare); + } + + None + } + + /// 获取云服务商名称 + pub fn name(&self) -> &str { + match self { + Self::Aws => "aws", + Self::Azure => "azure", + Self::Gcp => "gcp", + Self::DigitalOcean => "digitalocean", + Self::Cloudflare => "cloudflare", + Self::Unknown(name) => name, + } + } + + /// 从字符串解析云服务商 + pub fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "aws" | "amazon" => Self::Aws, + "azure" | "microsoft" => Self::Azure, + "gcp" | "google" => Self::Gcp, + "digitalocean" | "do" => Self::DigitalOcean, + "cloudflare" | "cf" => Self::Cloudflare, + _ => Self::Unknown(s.to_string()), + } + } +} + +/// 云元数据获取器特征 +#[async_trait] +pub trait CloudMetadataFetcher: Send + Sync { + /// 获取云服务商名称 + fn provider_name(&self) -> &str; + + /// 获取实例所在的网络 CIDR 范围 + async fn fetch_network_cidrs(&self) -> Result, AppError>; + + /// 获取云服务商的公共 IP 范围 + async fn fetch_public_ip_ranges(&self) -> Result, AppError>; + + /// 获取可信代理的 IP 范围 + async fn fetch_trusted_proxy_ranges(&self) -> Result, AppError> { + let mut ranges = Vec::new(); + + // 尝试获取网络 CIDR + match self.fetch_network_cidrs().await { + Ok(cidrs) => ranges.extend(cidrs), + Err(e) => warn!("Failed to fetch network CIDRs from {}: {}", self.provider_name(), e), + } + + // 尝试获取公共 IP 范围 + match self.fetch_public_ip_ranges().await { + Ok(public_ranges) => ranges.extend(public_ranges), + Err(e) => warn!("Failed to fetch public IP ranges from {}: {}", self.provider_name(), e), + } + + Ok(ranges) + } +} + +/// 云服务检测器 +#[derive(Debug, Clone)] +pub struct CloudDetector { + /// 是否启用云检测 + enabled: bool, + /// 超时时间 + timeout: Duration, + /// 强制指定的云服务商 + forced_provider: Option, +} + +impl CloudDetector { + /// 创建新的云检测器 + pub fn new(enabled: bool, timeout: Duration, forced_provider: Option) -> Self { + let forced_provider = forced_provider.map(|s| CloudProvider::from_str(&s)); + + Self { + enabled, + timeout, + forced_provider, + } + } + + /// 检测云服务商 + pub fn detect_provider(&self) -> Option { + if !self.enabled { + return None; + } + + // 如果强制指定了云服务商,直接返回 + if let Some(provider) = self.forced_provider { + return Some(provider); + } + + // 自动检测 + CloudProvider::detect_from_env() + } + + /// 获取可信代理 IP 范围 + pub async fn fetch_trusted_ranges(&self) -> Result, AppError> { + if !self.enabled { + debug!("Cloud metadata fetching is disabled"); + return Ok(Vec::new()); + } + + let provider = self.detect_provider(); + + match provider { + Some(CloudProvider::Aws) => { + info!("Detected AWS environment, fetching metadata"); + let fetcher = crate::cloud::metadata::AwsMetadataFetcher::new(); + fetcher.fetch_trusted_proxy_ranges().await + } + Some(CloudProvider::Azure) => { + info!("Detected Azure environment, fetching metadata"); + let fetcher = crate::cloud::metadata::AzureMetadataFetcher::new(); + fetcher.fetch_trusted_proxy_ranges().await + } + Some(CloudProvider::Gcp) => { + info!("Detected GCP environment, fetching metadata"); + let fetcher = crate::cloud::metadata::GcpMetadataFetcher::new(); + fetcher.fetch_trusted_proxy_ranges().await + } + Some(CloudProvider::Cloudflare) => { + info!("Detected Cloudflare environment"); + let ranges = crate::cloud::ranges::CloudflareIpRanges::fetch().await?; + Ok(ranges) + } + Some(CloudProvider::DigitalOcean) => { + info!("Detected DigitalOcean environment"); + let ranges = crate::cloud::ranges::DigitalOceanIpRanges::fetch().await?; + Ok(ranges) + } + Some(CloudProvider::Unknown(name)) => { + warn!("Unknown cloud provider detected: {}", name); + Ok(Vec::new()) + } + None => { + debug!("No cloud provider detected"); + Ok(Vec::new()) + } + } + } + + /// 尝试所有云服务商获取元数据 + pub async fn try_all_providers(&self) -> Result, AppError> { + if !self.enabled { + return Ok(Vec::new()); + } + + let providers: Vec> = vec![ + Box::new(crate::cloud::metadata::AwsMetadataFetcher::new()), + Box::new(crate::cloud::metadata::AzureMetadataFetcher::new()), + Box::new(crate::cloud::metadata::GcpMetadataFetcher::new()), + ]; + + for provider in providers { + let provider_name = provider.provider_name(); + debug!("Trying to fetch metadata from {}", provider_name); + + match provider.fetch_trusted_proxy_ranges().await { + Ok(ranges) => { + if !ranges.is_empty() { + info!("Fetched {} IP ranges from {}", ranges.len(), provider_name); + return Ok(ranges); + } + } + Err(e) => { + debug!("Failed to fetch metadata from {}: {}", provider_name, e); + } + } + } + + Ok(Vec::new()) + } +} + +/// 默认云检测器 +pub fn default_cloud_detector() -> CloudDetector { + CloudDetector::new( + false, // 默认禁用 + Duration::from_secs(5), + None, + ) +} diff --git a/crates/trusted-proxies/src/cloud/metadata/aws.rs b/crates/trusted-proxies/src/cloud/metadata/aws.rs index 6238cfff..738d760d 100644 --- a/crates/trusted-proxies/src/cloud/metadata/aws.rs +++ b/crates/trusted-proxies/src/cloud/metadata/aws.rs @@ -11,3 +11,143 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! AWS metadata fetching implementation + +use async_trait::async_trait; +use reqwest::Client; +use std::str::FromStr; +use std::time::Duration; +use tracing::{debug, info}; + +use crate::cloud::detector::CloudMetadataFetcher; +use crate::error::AppError; + +/// AWS 元数据获取器 +#[derive(Debug, Clone)] +pub struct AwsMetadataFetcher { + client: Client, + metadata_endpoint: String, +} + +impl AwsMetadataFetcher { + /// 创建新的 AWS 元数据获取器 + pub fn new() -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(2)) + .build() + .unwrap_or_else(|_| Client::new()); + + Self { + client, + metadata_endpoint: "http://169.254.169.254".to_string(), + } + } + + /// 获取 IMDSv2 令牌 + async fn get_metadata_token(&self) -> Result { + let url = format!("{}/latest/api/token", self.metadata_endpoint); + + match self + .client + .put(&url) + .header("X-aws-ec2-metadata-token-ttl-seconds", "21600") + .send() + .await + { + Ok(response) => { + if response.status().is_success() { + let token = response + .text() + .await + .map_err(|e| AppError::cloud(format!("Failed to read token: {}", e)))?; + Ok(token) + } else { + debug!("IMDSv2 token request failed with status: {}", response.status()); + Err(AppError::cloud("Failed to get IMDSv2 token".to_string())) + } + } + Err(e) => { + debug!("IMDSv2 token request failed: {}", e); + Err(AppError::cloud(format!("IMDSv2 request failed: {}", e))) + } + } + } +} + +#[async_trait] +impl CloudMetadataFetcher for AwsMetadataFetcher { + fn provider_name(&self) -> &str { + "aws" + } + + async fn fetch_network_cidrs(&self) -> Result, AppError> { + // 简化实现:返回常见的 AWS VPC 范围 + let default_ranges = vec![ + "10.0.0.0/8", // 大型 VPC + "172.16.0.0/12", // 中型 VPC + "192.168.0.0/16", // 小型 VPC + ]; + + let networks: Result, _> = default_ranges + .into_iter() + .map(|s| ipnetwork::IpNetwork::from_str(s)) + .collect(); + + match networks { + Ok(networks) => { + debug!("Using default AWS network ranges"); + Ok(networks) + } + Err(e) => Err(AppError::cloud(format!("Failed to parse default ranges: {}", e))), + } + } + + async fn fetch_public_ip_ranges(&self) -> Result, AppError> { + let url = "https://ip-ranges.amazonaws.com/ip-ranges.json"; + + #[derive(Debug, serde::Deserialize)] + struct AwsIpRanges { + prefixes: Vec, + } + + #[derive(Debug, serde::Deserialize)] + struct AwsPrefix { + ip_prefix: String, + region: String, + service: String, + } + + match self.client.get(url).timeout(Duration::from_secs(5)).send().await { + Ok(response) => { + if response.status().is_success() { + let ip_ranges: AwsIpRanges = response + .json() + .await + .map_err(|e| AppError::cloud(format!("Failed to parse AWS IP ranges: {}", e)))?; + + let mut networks = Vec::new(); + + for prefix in ip_ranges.prefixes { + // 只包含 EC2 和 CloudFront 的 IP 范围 + if prefix.service == "EC2" || prefix.service == "CLOUDFRONT" { + if let Ok(network) = ipnetwork::IpNetwork::from_str(&prefix.ip_prefix) { + networks.push(network); + } + } + } + + info!("Fetched {} AWS public IP ranges", networks.len()); + Ok(networks) + } else { + debug!("Failed to fetch AWS IP ranges: {}", response.status()); + Ok(Vec::new()) + } + } + Err(e) => { + debug!("Failed to fetch AWS IP ranges: {}", e); + Ok(Vec::new()) + } + } + } +} diff --git a/crates/trusted-proxies/src/cloud/metadata/azure.rs b/crates/trusted-proxies/src/cloud/metadata/azure.rs index 6238cfff..0736290c 100644 --- a/crates/trusted-proxies/src/cloud/metadata/azure.rs +++ b/crates/trusted-proxies/src/cloud/metadata/azure.rs @@ -11,3 +11,307 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Azure Cloud metadata fetching implementation + +use async_trait::async_trait; +use reqwest::Client; +use serde::Deserialize; +use std::time::Duration; +use tracing::{debug, info, warn}; + +use crate::cloud::detector::CloudMetadataFetcher; +use crate::error::AppError; + +/// Azure 元数据获取器 +#[derive(Debug, Clone)] +pub struct AzureMetadataFetcher { + client: Client, + metadata_endpoint: String, +} + +impl AzureMetadataFetcher { + /// 创建新的 Azure 元数据获取器 + pub fn new() -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(2)) + .build() + .unwrap_or_else(|_| Client::new()); + + Self { + client, + metadata_endpoint: "http://169.254.169.254".to_string(), + } + } + + /// 获取 Azure 元数据 + async fn get_metadata(&self, path: &str) -> Result { + let url = format!("{}/metadata/{}?api-version=2021-05-01", self.metadata_endpoint, path); + + debug!("Fetching Azure metadata from: {}", url); + + match self.client.get(&url).header("Metadata", "true").send().await { + Ok(response) => { + if response.status().is_success() { + let text = response + .text() + .await + .map_err(|e| AppError::cloud(format!("Failed to read response: {}", e)))?; + Ok(text) + } else { + debug!("Azure metadata request failed with status: {}", response.status()); + Err(AppError::cloud(format!("Azure metadata API returned status: {}", response.status()))) + } + } + Err(e) => { + debug!("Azure metadata request failed: {}", e); + Err(AppError::cloud(format!("Azure metadata request failed: {}", e))) + } + } + } + + /// 从 Microsoft 下载 IP 范围 + async fn fetch_azure_ip_ranges(&self) -> Result, AppError> { + // Azure 官方 IP 范围下载 URL + let url = + "https://download.microsoft.com/download/7/1/D/71D86715-5596-4529-9B13-DA13A5DE5B63/ServiceTags_Public_20231211.json"; + + #[derive(Debug, Deserialize)] + struct AzureServiceTags { + values: Vec, + } + + #[derive(Debug, Deserialize)] + struct AzureServiceTag { + id: String, + name: String, + properties: AzureServiceTagProperties, + } + + #[derive(Debug, Deserialize)] + struct AzureServiceTagProperties { + address_prefixes: Vec, + region: Option, + system_service: Option, + } + + debug!("Fetching Azure IP ranges from: {}", url); + + match self.client.get(url).timeout(Duration::from_secs(10)).send().await { + Ok(response) => { + if response.status().is_success() { + let service_tags: AzureServiceTags = response + .json() + .await + .map_err(|e| AppError::cloud(format!("Failed to parse Azure IP ranges: {}", e)))?; + + let mut networks = Vec::new(); + + for tag in service_tags.values { + // 只包含 Azure 数据中心和前端服务的 IP 范围 + if tag.name.contains("Azure") && !tag.name.contains("ActiveDirectory") { + for prefix in tag.properties.address_prefixes { + if let Ok(network) = ipnetwork::IpNetwork::from_str(&prefix) { + networks.push(network); + } + } + } + } + + info!("Fetched {} Azure public IP ranges", networks.len()); + Ok(networks) + } else { + debug!("Failed to fetch Azure IP ranges: {}", response.status()); + Ok(Vec::new()) + } + } + Err(e) => { + debug!("Failed to fetch Azure IP ranges: {}", e); + // 如果 API 失败,返回默认的 Azure IP 范围 + Self::default_azure_ranges() + } + } + } + + /// 默认 Azure IP 范围(作为备选) + fn default_azure_ranges() -> Result, AppError> { + let ranges = vec![ + // Azure 全球 IP 范围 + "13.64.0.0/11", + "13.96.0.0/13", + "13.104.0.0/14", + "20.33.0.0/16", + "20.34.0.0/15", + "20.36.0.0/14", + "20.40.0.0/13", + "20.48.0.0/12", + "20.64.0.0/10", + "20.128.0.0/16", + "20.135.0.0/16", + "20.136.0.0/13", + "20.150.0.0/15", + "20.157.0.0/16", + "20.184.0.0/13", + "20.190.0.0/16", + "20.192.0.0/10", + "40.64.0.0/10", + "40.80.0.0/12", + "40.96.0.0/13", + "40.112.0.0/13", + "40.120.0.0/14", + "40.124.0.0/16", + "40.125.0.0/17", + "51.12.0.0/15", + "51.104.0.0/15", + "51.120.0.0/16", + "51.124.0.0/16", + "51.132.0.0/16", + "51.136.0.0/15", + "51.138.0.0/16", + "51.140.0.0/14", + "51.144.0.0/15", + "52.96.0.0/12", + "52.112.0.0/14", + "52.120.0.0/14", + "52.124.0.0/16", + "52.125.0.0/16", + "52.126.0.0/15", + "52.130.0.0/15", + "52.136.0.0/13", + "52.144.0.0/15", + "52.146.0.0/15", + "52.148.0.0/14", + "52.152.0.0/13", + "52.160.0.0/12", + "52.176.0.0/13", + "52.184.0.0/14", + "52.188.0.0/14", + "52.224.0.0/11", + "65.52.0.0/14", + "104.40.0.0/13", + "104.208.0.0/13", + "104.215.0.0/16", + "137.116.0.0/15", + "137.135.0.0/16", + "138.91.0.0/16", + "157.56.0.0/16", + "168.61.0.0/16", + "168.62.0.0/15", + "191.233.0.0/18", + "193.149.0.0/19", + // IPv6 范围 + "2603:1000::/40", + "2603:1010::/40", + "2603:1020::/40", + "2603:1030::/40", + "2603:1040::/40", + "2603:1050::/40", + "2603:1060::/40", + "2603:1070::/40", + "2603:1080::/40", + "2603:1090::/40", + "2603:10a0::/40", + "2603:10b0::/40", + "2603:10c0::/40", + "2603:10d0::/40", + "2603:10e0::/40", + "2603:10f0::/40", + "2603:1100::/40", + ]; + + let networks: Result, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect(); + + match networks { + Ok(networks) => { + debug!("Using default Azure IP ranges"); + Ok(networks) + } + Err(e) => Err(AppError::cloud(format!("Failed to parse default Azure ranges: {}", e))), + } + } +} + +#[async_trait] +impl CloudMetadataFetcher for AzureMetadataFetcher { + fn provider_name(&self) -> &str { + "azure" + } + + async fn fetch_network_cidrs(&self) -> Result, AppError> { + // 尝试从 Azure 元数据获取网络信息 + match self.get_metadata("instance/network/interface").await { + Ok(metadata) => { + #[derive(Debug, Deserialize)] + struct AzureNetworkInterface { + ipv4: AzureIpv4Info, + } + + #[derive(Debug, Deserialize)] + struct AzureIpv4Info { + subnet: Vec, + } + + #[derive(Debug, Deserialize)] + struct AzureSubnet { + address: String, + prefix: String, + } + + let interfaces: Vec = serde_json::from_str(&metadata) + .map_err(|e| AppError::cloud(format!("Failed to parse Azure network metadata: {}", e)))?; + + let mut cidrs = Vec::new(); + for interface in interfaces { + for subnet in interface.ipv4.subnet { + let cidr = format!("{}/{}", subnet.address, subnet.prefix); + if let Ok(network) = ipnetwork::IpNetwork::from_str(&cidr) { + cidrs.push(network); + } + } + } + + if !cidrs.is_empty() { + info!("Fetched {} network CIDRs from Azure metadata", cidrs.len()); + Ok(cidrs) + } else { + // 如果元数据中没有网络信息,使用默认的 Azure VNet 范围 + debug!("No network CIDRs found in Azure metadata, using defaults"); + Self::default_azure_network_ranges() + } + } + Err(e) => { + warn!("Failed to fetch Azure network metadata: {}", e); + // 元数据获取失败,使用默认范围 + Self::default_azure_network_ranges() + } + } + } + + async fn fetch_public_ip_ranges(&self) -> Result, AppError> { + self.fetch_azure_ip_ranges().await + } +} + +impl AzureMetadataFetcher { + /// 默认 Azure 网络范围 + fn default_azure_network_ranges() -> Result, AppError> { + // Azure 虚拟网络的常见 IP 范围 + let ranges = vec![ + "10.0.0.0/8", // 大型虚拟网络 + "172.16.0.0/12", // 中型虚拟网络 + "192.168.0.0/16", // 小型虚拟网络 + "100.64.0.0/10", // Azure 保留范围 + "192.0.0.0/24", // Azure 保留 + ]; + + let networks: Result, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect(); + + match networks { + Ok(networks) => { + debug!("Using default Azure network ranges"); + Ok(networks) + } + Err(e) => Err(AppError::cloud(format!("Failed to parse default network ranges: {}", e))), + } + } +} diff --git a/crates/trusted-proxies/src/cloud/metadata/gcp.rs b/crates/trusted-proxies/src/cloud/metadata/gcp.rs index 6238cfff..8a3eabbd 100644 --- a/crates/trusted-proxies/src/cloud/metadata/gcp.rs +++ b/crates/trusted-proxies/src/cloud/metadata/gcp.rs @@ -11,3 +11,370 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Google Cloud Platform (GCP) metadata fetching implementation + +use async_trait::async_trait; +use reqwest::Client; +use serde::Deserialize; +use std::str::FromStr; +use std::time::Duration; +use tracing::{debug, info, warn}; + +use crate::cloud::detector::CloudMetadataFetcher; +use crate::error::AppError; + +/// GCP 元数据获取器 +#[derive(Debug, Clone)] +pub struct GcpMetadataFetcher { + client: Client, + metadata_endpoint: String, +} + +impl GcpMetadataFetcher { + /// 创建新的 GCP 元数据获取器 + pub fn new() -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(2)) + .build() + .unwrap_or_else(|_| Client::new()); + + Self { + client, + metadata_endpoint: "http://metadata.google.internal".to_string(), + } + } + + /// 获取 GCP 元数据 + async fn get_metadata(&self, path: &str) -> Result { + let url = format!("{}/computeMetadata/v1/{}", self.metadata_endpoint, path); + + debug!("Fetching GCP metadata from: {}", url); + + match self.client.get(&url).header("Metadata-Flavor", "Google").send().await { + Ok(response) => { + if response.status().is_success() { + let text = response + .text() + .await + .map_err(|e| AppError::cloud(format!("Failed to read response: {}", e)))?; + Ok(text) + } else { + debug!("GCP metadata request failed with status: {}", response.status()); + Err(AppError::cloud(format!("GCP metadata API returned status: {}", response.status()))) + } + } + Err(e) => { + debug!("GCP metadata request failed: {}", e); + Err(AppError::cloud(format!("GCP metadata request failed: {}", e))) + } + } + } + + /// 获取网络掩码的前缀长度 + fn subnet_mask_to_prefix_length(mask: &str) -> Result { + let parts: Vec<&str> = mask.split('.').collect(); + if parts.len() != 4 { + return Err(AppError::cloud(format!("Invalid subnet mask: {}", mask))); + } + + let mut prefix_length = 0; + for part in parts { + let octet: u8 = part + .parse() + .map_err(|_| AppError::cloud(format!("Invalid octet in subnet mask: {}", part)))?; + + let mut remaining = octet; + while remaining > 0 { + if remaining & 0x80 == 0x80 { + prefix_length += 1; + remaining <<= 1; + } else { + break; + } + } + + if remaining != 0 { + return Err(AppError::cloud("Non-contiguous subnet mask".to_string())); + } + } + + Ok(prefix_length) + } +} + +#[async_trait] +impl CloudMetadataFetcher for GcpMetadataFetcher { + fn provider_name(&self) -> &str { + "gcp" + } + + async fn fetch_network_cidrs(&self) -> Result, AppError> { + // 获取网络接口列表 + match self.get_metadata("instance/network-interfaces/").await { + Ok(interfaces_metadata) => { + // 解析网络接口索引 + let interface_indices: Vec = interfaces_metadata + .lines() + .filter_map(|line| { + let line = line.trim().trim_end_matches('/'); + if line.chars().all(|c| c.is_ascii_digit()) { + line.parse().ok() + } else { + None + } + }) + .collect(); + + if interface_indices.is_empty() { + warn!("No network interfaces found in GCP metadata"); + return Self::default_gcp_network_ranges(); + } + + let mut cidrs = Vec::new(); + + for index in interface_indices { + // 获取子网信息 + let subnet_path = format!("instance/network-interfaces/{}/subnetworks", index); + + if let Ok(subnet_metadata) = self.get_metadata(&subnet_path).await { + // 子网元数据可能包含多个子网,取第一个 + if let Some(first_subnet) = subnet_metadata.lines().next() { + let subnet = first_subnet.trim(); + if !subnet.is_empty() { + // 尝试从子网名称提取网络信息 + if let Some(network) = Self::extract_network_from_subnet_name(subnet) { + cidrs.push(network); + continue; + } + } + } + } + + // 备选方案:使用 IP 地址和子网掩码 + let ip_path = format!("instance/network-interfaces/{}/ip", index); + let mask_path = format!("instance/network-interfaces/{}/subnetmask", index); + + match tokio::try_join!(self.get_metadata(&ip_path), self.get_metadata(&mask_path)) { + Ok((ip, mask)) => { + let ip = ip.trim(); + let mask = mask.trim(); + + if let (Ok(ip_addr), Ok(prefix_len)) = + (std::net::Ipv4Addr::from_str(ip), Self::subnet_mask_to_prefix_length(mask)) + { + let cidr_str = format!("{}/{}", ip_addr, prefix_len); + if let Ok(network) = ipnetwork::IpNetwork::from_str(&cidr_str) { + cidrs.push(network); + } + } + } + Err(e) => { + debug!("Failed to get IP/mask for interface {}: {}", index, e); + } + } + } + + if cidrs.is_empty() { + warn!("Could not determine network CIDRs from GCP metadata"); + Self::default_gcp_network_ranges() + } else { + info!("Fetched {} network CIDRs from GCP metadata", cidrs.len()); + Ok(cidrs) + } + } + Err(e) => { + warn!("Failed to fetch GCP network metadata: {}", e); + Self::default_gcp_network_ranges() + } + } + } + + async fn fetch_public_ip_ranges(&self) -> Result, AppError> { + self.fetch_gcp_ip_ranges().await + } +} + +impl GcpMetadataFetcher { + /// 从 Google API 获取 IP 范围 + async fn fetch_gcp_ip_ranges(&self) -> Result, AppError> { + let url = "https://www.gstatic.com/ipranges/cloud.json"; + + #[derive(Debug, Deserialize)] + struct GcpIpRanges { + prefixes: Vec, + } + + #[derive(Debug, Deserialize)] + struct GcpPrefix { + ipv4_prefix: Option, + ipv6_prefix: Option, + } + + debug!("Fetching GCP IP ranges from: {}", url); + + match self.client.get(url).timeout(Duration::from_secs(10)).send().await { + Ok(response) => { + if response.status().is_success() { + let ip_ranges: GcpIpRanges = response + .json() + .await + .map_err(|e| AppError::cloud(format!("Failed to parse GCP IP ranges: {}", e)))?; + + let mut networks = Vec::new(); + + for prefix in ip_ranges.prefixes { + if let Some(ipv4_prefix) = prefix.ipv4_prefix { + if let Ok(network) = ipnetwork::IpNetwork::from_str(&ipv4_prefix) { + networks.push(network); + } + } + } + + info!("Fetched {} GCP public IP ranges", networks.len()); + Ok(networks) + } else { + debug!("Failed to fetch GCP IP ranges: {}", response.status()); + Self::default_gcp_ip_ranges() + } + } + Err(e) => { + debug!("Failed to fetch GCP IP ranges: {}", e); + Self::default_gcp_ip_ranges() + } + } + } + + /// 默认 GCP IP 范围(作为备选) + fn default_gcp_ip_ranges() -> Result, AppError> { + let ranges = vec![ + // GCP 全球 IP 范围 + "8.34.208.0/20", + "8.35.192.0/20", + "8.35.208.0/20", + "23.236.48.0/20", + "23.251.128.0/19", + "34.0.0.0/15", + "34.2.0.0/16", + "34.3.0.0/23", + "34.3.3.0/24", + "34.3.4.0/24", + "34.3.8.0/21", + "34.3.16.0/20", + "34.3.32.0/19", + "34.3.64.0/18", + "34.3.128.0/17", + "34.4.0.0/14", + "34.8.0.0/13", + "34.16.0.0/12", + "34.32.0.0/11", + "34.64.0.0/10", + "34.128.0.0/10", + "35.184.0.0/13", + "35.192.0.0/14", + "35.196.0.0/15", + "35.198.0.0/16", + "35.199.0.0/17", + "35.199.128.0/18", + "35.200.0.0/13", + "35.208.0.0/12", + "35.224.0.0/12", + "35.240.0.0/13", + "104.154.0.0/15", + "104.196.0.0/14", + "107.167.160.0/19", + "107.178.192.0/18", + "108.59.80.0/20", + "108.170.192.0/18", + "108.177.0.0/17", + "130.211.0.0/16", + "136.112.0.0/12", + "142.250.0.0/15", + "146.148.0.0/17", + "162.216.148.0/22", + "162.222.176.0/21", + "172.217.0.0/16", + "172.253.0.0/16", + "173.194.0.0/16", + "192.158.28.0/22", + "192.178.0.0/15", + "193.186.4.0/24", + "199.36.154.0/23", + "199.36.156.0/24", + "199.192.112.0/22", + "199.223.232.0/21", + "207.223.160.0/20", + "208.65.152.0/22", + "208.68.108.0/22", + "208.81.188.0/22", + "208.117.224.0/19", + "209.85.128.0/17", + "216.58.192.0/19", + "216.73.80.0/20", + "216.239.32.0/19", + // IPv6 范围 + "2001:4860::/32", + "2404:6800::/32", + "2600:1900::/28", + "2607:f8b0::/32", + "2620:15c::/36", + "2800:3f0::/32", + "2a00:1450::/32", + "2c0f:fb50::/32", + ]; + + let networks: Result, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect(); + + match networks { + Ok(networks) => { + debug!("Using default GCP IP ranges"); + Ok(networks) + } + Err(e) => Err(AppError::cloud(format!("Failed to parse default GCP ranges: {}", e))), + } + } + + /// 默认 GCP 网络范围 + fn default_gcp_network_ranges() -> Result, AppError> { + // GCP VPC 网络的常见 IP 范围 + let ranges = vec![ + "10.0.0.0/8", // 大型 VPC 网络 + "172.16.0.0/12", // 中型 VPC 网络 + "192.168.0.0/16", // 小型 VPC 网络 + "100.64.0.0/10", // GCP 保留范围 + ]; + + let networks: Result, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect(); + + match networks { + Ok(networks) => { + debug!("Using default GCP network ranges"); + Ok(networks) + } + Err(e) => Err(AppError::cloud(format!("Failed to parse default GCP network ranges: {}", e))), + } + } + + /// 从子网名称提取网络信息 + fn extract_network_from_subnet_name(subnet_name: &str) -> Option { + // GCP 子网名称格式通常为:regions/{region}/subnetworks/{subnet-name} + // 或者 projects/{project}/regions/{region}/subnetworks/{subnet-name} + + // 尝试从子网名称中提取 IP 范围 + // 这只是一个简化的实现,实际可能需要查询 GCP API + + // 常见的 GCP 子网 IP 范围模式 + let patterns = [("10.", 8), ("172.16.", 12), ("192.168.", 16)]; + + for (prefix, prefix_len) in patterns { + if subnet_name.contains(&format!("subnet-{}", prefix.replace(".", "-"))) { + let cidr = format!("{}{}", prefix, "0.0.0/".to_string() + &prefix_len.to_string()); + if let Ok(network) = ipnetwork::IpNetwork::from_str(&cidr) { + return Some(network); + } + } + } + + None + } +} diff --git a/crates/trusted-proxies/src/cloud/metadata/mod.rs b/crates/trusted-proxies/src/cloud/metadata/mod.rs index b610d161..31a4fafd 100644 --- a/crates/trusted-proxies/src/cloud/metadata/mod.rs +++ b/crates/trusted-proxies/src/cloud/metadata/mod.rs @@ -12,6 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Cloud provider metadata fetching +//! +//! This module contains implementations for fetching metadata +//! from various cloud providers. + mod aws; mod azure; mod gcp; + +pub use aws::*; +pub use azure::*; +pub use gcp::*; diff --git a/crates/trusted-proxies/src/cloud/mod.rs b/crates/trusted-proxies/src/cloud/mod.rs index 3b9b9c40..338005ba 100644 --- a/crates/trusted-proxies/src/cloud/mod.rs +++ b/crates/trusted-proxies/src/cloud/mod.rs @@ -12,8 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Cloud service integration module +//! +//! This module provides integration with various cloud providers +//! for automatic IP range detection and metadata fetching. + mod detector; -mod metadata; +pub mod metadata; mod ranges; +pub use detector::*; pub use ranges::*; + +// Re-export metadata module types +pub use metadata::*; diff --git a/crates/trusted-proxies/src/cloud/ranges.rs b/crates/trusted-proxies/src/cloud/ranges.rs index e455398b..c429b51a 100644 --- a/crates/trusted-proxies/src/cloud/ranges.rs +++ b/crates/trusted-proxies/src/cloud/ranges.rs @@ -12,872 +12,207 @@ // See the License for the specific language governing permissions and // limitations under the License. -// src/cloud/metadata.rs +//! Cloud provider IP range definitions + +use std::str::FromStr; +use std::time::Duration; -use async_trait::async_trait; use ipnetwork::IpNetwork; use reqwest::Client; -use serde::Deserialize; -use std::net::Ipv4Addr; -use std::time::Duration; -use thiserror::Error; -use tracing::{debug, info, warn}; - -/// Error in obtaining cloud service provider metadata -#[derive(Error, Debug)] -pub enum CloudMetadataError { - #[error("HTTP request failed: {0}")] - HttpRequestFailed(#[from] reqwest::Error), - - #[error("JSON parsing fails: {0}")] - JsonParseError(#[from] serde_json::Error), - - #[error("Metadata Service Unavailable: {0}")] - MetadataUnavailable(String), - - #[error("IP address resolution failed: {0}")] - IpParseError(String), - - #[error("Unsupported cloud service providers: {0}")] - UnsupportedProvider(String), - - #[error("Misconfiguration: {0}")] - ConfigurationError(String), -} - -/// Cloud service provider type -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum CloudProvider { - Aws, - Azure, - Gcp, - DigitalOcean, - Vultr, - Linode, - Oracle, - Alibaba, - Tencent, - /// Unknown or customized - Unknown(String), -} - -impl CloudProvider { - /// Automatically detect cloud service providers from environment variables - pub fn detect_from_env() -> Option { - // Check various cloud service provider-specific environment variables - - // AWS - if std::env::var("AWS_EXECUTION_ENV").is_ok() - || std::env::var("AWS_REGION").is_ok() - || std::env::var("EC2_INSTANCE_ID").is_ok() - { - return Some(Self::Aws); - } - - // Azure - if std::env::var("WEBSITE_SITE_NAME").is_ok() - || std::env::var("WEBSITE_INSTANCE_ID").is_ok() - || std::env::var("APPSETTING_WEBSITE_SITE_NAME").is_ok() - { - return Some(Self::Azure); - } - - // GCP - if std::env::var("GCP_PROJECT").is_ok() - || std::env::var("GOOGLE_CLOUD_PROJECT").is_ok() - || std::env::var("GAE_INSTANCE").is_ok() - { - return Some(Self::Gcp); - } - - // DigitalOcean - if std::env::var("DIGITALOCEAN_REGION").is_ok() { - return Some(Self::DigitalOcean); - } - - // Vultr - if std::env::var("VULTR_REGION").is_ok() { - return Some(Self::Vultr); - } - - // Linode - if std::env::var("LINODE_REGION").is_ok() { - return Some(Self::Linode); - } - - // Oracle - if std::env::var("OCI_REGION").is_ok() { - return Some(Self::Oracle); - } - - // Alibaba Cloud - if std::env::var("ALIBABA_CLOUD_REGION").is_ok() { - return Some(Self::Alibaba); - } - - // Tencent Cloud - if std::env::var("TENCENTCLOUD_REGION").is_ok() { - return Some(Self::Tencent); - } - - None - } - - /// Get the cloud service provider name - pub fn name(&self) -> &str { - match self { - Self::Aws => "aws", - Self::Azure => "azure", - Self::Gcp => "gcp", - Self::DigitalOcean => "digitalocean", - Self::Vultr => "vultr", - Self::Linode => "linode", - Self::Oracle => "oracle", - Self::Alibaba => "alibaba", - Self::Tencent => "tencent", - Self::Unknown(name) => name, - } - } -} - -/// Cloud metadata fetcher characteristics -#[async_trait] -pub trait CloudMetadataFetcher { - /// Get the cloud service provider name - fn provider_name(&self) -> &str; - - /// Gets the CIDR range of the network where the instance is located - async fn fetch_network_cidrs(&self) -> Result, CloudMetadataError>; - - /// Get the cloud service provider's public IP range (e.g., load balancer, NAT gateway, etc.) - async fn fetch_public_ip_ranges(&self) -> Result, CloudMetadataError>; - - /// Get the IP range of a trusted proxy - async fn fetch_trusted_proxy_ranges(&self) -> Result, CloudMetadataError> { - // Default implementation: Merge network CIDR and public IP ranges - let mut ranges = Vec::new(); - - match self.fetch_network_cidrs().await { - Ok(cidrs) => ranges.extend(cidrs), - Err(e) => warn!("Get network CIDR failed: {}", e), - } - - match self.fetch_public_ip_ranges().await { - Ok(public_ranges) => ranges.extend(public_ranges), - Err(e) => warn!("Failed to get public IP range: {}", e), - } - - Ok(ranges) - } -} - -/// AWS Metadata Fetcher -pub struct AwsMetadataFetcher { - client: Client, - metadata_endpoint: String, -} - -impl Default for AwsMetadataFetcher { - fn default() -> Self { - Self::new() - } -} - -impl AwsMetadataFetcher { - pub fn new() -> Self { - let client = Client::builder() - .timeout(Duration::from_secs(2)) - .build() - .unwrap_or_else(|_| Client::new()); - - Self { - client, - metadata_endpoint: "http://169.254.169.254".to_string(), - } - } - - /// Get an IMDSv2 token - async fn get_metadata_token(&self) -> Result { - let url = format!("{}/latest/api/token", self.metadata_endpoint); - - let response = self - .client - .put(&url) - .header("X-aws-ec2-metadata-token-ttl-seconds", "21600") - .send() - .await - .map_err(|e| { - debug!("AWS IMDSv2 token acquisition failed, try IMDSv1: {}", e); - CloudMetadataError::MetadataUnavailable(format!("IMDSv2 failed: {}", e)) - })?; - - if response.status().is_success() { - let token = response.text().await?; - Ok(token) - } else { - Err(CloudMetadataError::MetadataUnavailable("Unable to obtain IMDSv2 tokens".to_string())) - } - } - - /// Use tokens to get metadata - async fn get_metadata_with_token(&self, path: &str, token: Option<&str>) -> Result { - let url = format!("{}/latest/{}", self.metadata_endpoint, path); - - let mut request = self.client.get(&url); - - if let Some(t) = token { - request = request.header("X-aws-ec2-metadata-token", t); - } - - let response = request.send().await?; - - if response.status().is_success() { - let text = response.text().await?; - Ok(text) - } else { - Err(CloudMetadataError::MetadataUnavailable(format!( - "Metadata path {} returns status:{}", - path, - response.status() - ))) - } - } - - /// Get a list of MAC addresses - async fn get_mac_addresses(&self, token: Option<&str>) -> Result, CloudMetadataError> { - let text = self - .get_metadata_with_token("meta-data/network/interfaces/macs/", token) - .await?; - - let macs: Vec = text - .lines() - .map(|line| line.trim().trim_end_matches('/')) - .filter(|mac| !mac.is_empty()) - .map(String::from) - .collect(); - - Ok(macs) - } - - /// Get the VPC CIDR block - async fn get_vpc_cidr_blocks(&self, token: Option<&str>) -> Result, CloudMetadataError> { - let macs = self.get_mac_addresses(token).await?; - let mut cidrs = Vec::new(); - - for mac in macs { - let path = format!("meta-data/network/interfaces/macs/{}/vpc-ipv4-cidr-block", mac); - - match self.get_metadata_with_token(&path, token).await { - Ok(cidr_text) => { - let cidr_text = cidr_text.trim(); - if let Ok(network) = cidr_text.parse::() { - cidrs.push(network); - debug!("To get a VPC CIDR: {}", network); - } - } - Err(e) => { - debug!("Unable to obtain VPC CIDR for MAC {}: {}", mac, e); - } - } - } - - Ok(cidrs) - } - - /// Get the subnet CIDR block - async fn get_subnet_cidr_blocks(&self, token: Option<&str>) -> Result, CloudMetadataError> { - let macs = self.get_mac_addresses(token).await?; - let mut cidrs = Vec::new(); - - for mac in macs { - let path = format!("meta-data/network/interfaces/macs/{}/subnet-ipv4-cidr-block", mac); - - match self.get_metadata_with_token(&path, token).await { - Ok(cidr_text) => { - let cidr_text = cidr_text.trim(); - if let Ok(network) = cidr_text.parse::() { - cidrs.push(network); - debug!("Get the subnet CIDR: {}", network); - } - } - Err(e) => { - debug!("The subnet CIDR for MAC cannot be obtained {}: {}", mac, e); - } - } - } - - Ok(cidrs) - } - - /// Get public IP ranges for AWS (from official sources) - async fn get_aws_public_ip_ranges(&self) -> Result, CloudMetadataError> { - let url = "https://ip-ranges.amazonaws.com/ip-ranges.json"; - - #[derive(Debug, Deserialize)] - struct AwsIpRanges { - prefixes: Vec, - } - - #[derive(Debug, Deserialize)] - struct AwsPrefix { - ip_prefix: String, - region: String, - service: String, - } - - let response = self.client.get(url).timeout(Duration::from_secs(5)).send().await?; - - if response.status().is_success() { - let ip_ranges: AwsIpRanges = response.json().await?; - - let mut ranges = Vec::new(); - for prefix in ip_ranges.prefixes { - // Include only service-specific IP ranges (e.g., EC2, CLOUDFRONT, etc.) - if matches!(prefix.service.as_str(), "EC2" | "CLOUDFRONT" | "ROUTE53" | "ROUTE53_HEALTHCHECKS") - && let Ok(network) = prefix.ip_prefix.parse::() - { - ranges.push(network); - } - } - - info!("{} public IP ranges are obtained from AWS officially", ranges.len()); - Ok(ranges) - } else { - Err(CloudMetadataError::MetadataUnavailable(format!( - "AWS IP Ranges API returns status:{}", - response.status() - ))) - } - } -} - -#[async_trait] -impl CloudMetadataFetcher for AwsMetadataFetcher { - fn provider_name(&self) -> &str { - "aws" - } - - async fn fetch_network_cidrs(&self) -> Result, CloudMetadataError> { - let mut cidrs = Vec::new(); - - // 尝试获取 IMDSv2 令牌 - let token = match self.get_metadata_token().await { - Ok(t) => Some(t), - Err(_) => { - debug!("Using IMDSv1 (no token)"); - None - } - }; - - // 获取 VPC CIDR - match self.get_vpc_cidr_blocks(token.as_deref()).await { - Ok(vpc_cidrs) => cidrs.extend(vpc_cidrs), - Err(e) => debug!("Failed to obtain VPC CIDR:{}", e), - } - - // 获取子网 CIDR - match self.get_subnet_cidr_blocks(token.as_deref()).await { - Ok(subnet_cidrs) => cidrs.extend(subnet_cidrs), - Err(e) => debug!("Failed to get subnet CIDR: {}", e), - } - - if cidrs.is_empty() { - Err(CloudMetadataError::MetadataUnavailable("No network CIDR can be obtained".to_string())) - } else { - info!("{} network CIDRs were obtained from AWS metadata", cidrs.len()); - Ok(cidrs) - } - } - - async fn fetch_public_ip_ranges(&self) -> Result, CloudMetadataError> { - self.get_aws_public_ip_ranges().await - } -} - -/// Azure Metadata Fetcher -pub struct AzureMetadataFetcher { - client: Client, - metadata_endpoint: String, -} - -impl Default for AzureMetadataFetcher { - fn default() -> Self { - Self::new() - } -} - -impl AzureMetadataFetcher { - pub fn new() -> Self { - let client = Client::builder() - .timeout(Duration::from_secs(2)) - .build() - .unwrap_or_else(|_| Client::new()); - - Self { - client, - metadata_endpoint: "http://169.254.169.254".to_string(), - } - } - - /// Get Azure metadata - async fn get_metadata(&self, path: &str) -> Result { - let url = format!("{}/metadata/{}?api-version=2021-05-01", self.metadata_endpoint, path); - - let response = self.client.get(&url).header("Metadata", "true").send().await?; - - if response.status().is_success() { - let text = response.text().await?; - Ok(text) - } else { - Err(CloudMetadataError::MetadataUnavailable(format!( - "Azure metadata path {} Return status: {}", - path, - response.status() - ))) - } - } - - /// Get Azure public IP ranges - async fn get_azure_public_ip_ranges(&self) -> Result, CloudMetadataError> { - // Azure Public IP Range Download URL - let urls = [ - "https://www.microsoft.com/en-us/download/confirmation.aspx?id=56519", - "https://download.microsoft.com/download/7/1/D/71D86715-5596-4529-9B13-DA13A5DE5B63/ServiceTags_Public_20231211.json", +use tracing::{debug, info}; + +use crate::error::AppError; + +/// Cloudflare IP 范围 +pub struct CloudflareIpRanges; + +impl CloudflareIpRanges { + /// 获取 Cloudflare IP 范围 + pub async fn fetch() -> Result, AppError> { + let ranges = vec![ + // IPv4 ranges + "103.21.244.0/22", + "103.22.200.0/22", + "103.31.4.0/22", + "104.16.0.0/13", + "104.24.0.0/14", + "108.162.192.0/18", + "131.0.72.0/22", + "141.101.64.0/18", + "162.158.0.0/15", + "172.64.0.0/13", + "173.245.48.0/20", + "188.114.96.0/20", + "190.93.240.0/20", + "197.234.240.0/22", + "198.41.128.0/17", + // IPv6 ranges + "2400:cb00::/32", + "2606:4700::/32", + "2803:f800::/32", + "2405:b500::/32", + "2405:8100::/32", + "2a06:98c0::/29", + "2c0f:f248::/32", ]; - for url in urls.iter() { - match self.fetch_azure_ip_ranges_from_url(url).await { - Ok(ranges) => { - info!("{} public IP ranges are downloaded from Azure", ranges.len()); - return Ok(ranges); - } - Err(e) => { - debug!("Failed to get Azure IP range from {}: {}", url, e); - } - } - } + let networks: Result, _> = ranges.into_iter().map(|s| IpNetwork::from_str(s)).collect(); - Err(CloudMetadataError::MetadataUnavailable( - "Azure public IP ranges cannot be obtained".to_string(), - )) + match networks { + Ok(networks) => { + info!("Loaded {} Cloudflare IP ranges", networks.len()); + Ok(networks) + } + Err(e) => Err(AppError::cloud(format!("Failed to parse Cloudflare IP ranges: {}", e))), + } } - async fn fetch_azure_ip_ranges_from_url(&self, url: &str) -> Result, CloudMetadataError> { - #[derive(Debug, Deserialize)] - struct AzureServiceTags { - values: Vec, - } + /// 从 Cloudflare API 获取 IP 范围 + pub async fn fetch_from_api() -> Result, AppError> { + let client = Client::builder() + .timeout(Duration::from_secs(10)) + .build() + .map_err(|e| AppError::cloud(format!("Failed to create HTTP client: {}", e)))?; - #[derive(Debug, Deserialize)] - struct AzureServiceTag { - properties: AzureServiceTagProperties, - } + let urls = ["https://www.cloudflare.com/ips-v4", "https://www.cloudflare.com/ips-v6"]; - #[derive(Debug, Deserialize)] - struct AzureServiceTagProperties { - address_prefixes: Vec, - } + let mut all_ranges = Vec::new(); - let response = self.client.get(url).timeout(Duration::from_secs(10)).send().await?; + for url in urls { + match client.get(url).send().await { + Ok(response) => { + if response.status().is_success() { + let text = response + .text() + .await + .map_err(|e| AppError::cloud(format!("Failed to read response from {}: {}", url, e)))?; - if response.status().is_success() { - let service_tags: AzureServiceTags = response.json().await?; + let ranges: Result, _> = text + .lines() + .map(|line| line.trim()) + .filter(|line| !line.is_empty()) + .map(|line| IpNetwork::from_str(line)) + .collect(); - let mut ranges = Vec::new(); - for tag in service_tags.values { - for prefix in tag.properties.address_prefixes { - if let Ok(network) = prefix.parse::() { - ranges.push(network); + match ranges { + Ok(mut networks) => { + debug!("Fetched {} IP ranges from {}", networks.len(), url); + all_ranges.append(&mut networks); + } + Err(e) => { + debug!("Failed to parse IP ranges from {}: {}", url, e); + } + } + } else { + debug!("Failed to fetch IP ranges from {}: {}", url, response.status()); } } - } - - Ok(ranges) - } else { - Err(CloudMetadataError::MetadataUnavailable(format!( - "Azure IP range URL returns status:{}", - response.status() - ))) - } - } -} - -#[async_trait] -impl CloudMetadataFetcher for AzureMetadataFetcher { - fn provider_name(&self) -> &str { - "azure" - } - - async fn fetch_network_cidrs(&self) -> Result, CloudMetadataError> { - // Azure metadata provides network interface information - let metadata = self.get_metadata("instance/network/interface").await?; - - #[derive(Debug, Deserialize)] - struct AzureNetworkInterface { - ipv4: AzureIpv4Info, - } - - #[derive(Debug, Deserialize)] - struct AzureIpv4Info { - subnet: Vec, - } - - #[derive(Debug, Deserialize)] - struct AzureSubnet { - address: String, - prefix: String, - } - - let interfaces: Vec = serde_json::from_str(&metadata)?; - - let mut cidrs = Vec::new(); - for interface in interfaces { - for subnet in interface.ipv4.subnet { - let cidr = format!("{}/{}", subnet.address, subnet.prefix); - if let Ok(network) = cidr.parse::() { - cidrs.push(network); + Err(e) => { + debug!("Failed to fetch from {}: {}", url, e); } } } - if cidrs.is_empty() { - Err(CloudMetadataError::MetadataUnavailable( - "Azure network CIDR can't be obtained".to_string(), - )) + if all_ranges.is_empty() { + // 如果 API 失败,回退到静态列表 + Self::fetch().await } else { - info!("{} network CIDRs were obtained from Azure metadata", cidrs.len()); - Ok(cidrs) + info!("Fetched {} Cloudflare IP ranges from API", all_ranges.len()); + Ok(all_ranges) } } +} - async fn fetch_public_ip_ranges(&self) -> Result, CloudMetadataError> { - self.get_azure_public_ip_ranges().await +/// DigitalOcean IP 范围 +pub struct DigitalOceanIpRanges; + +impl DigitalOceanIpRanges { + /// 获取 DigitalOcean IP 范围 + pub async fn fetch() -> Result, AppError> { + // DigitalOcean 的 IP 范围相对稳定,使用静态列表 + let ranges = vec![ + // 数据中心 IP 范围 + "64.227.0.0/16", + "138.197.0.0/16", + "139.59.0.0/16", + "157.230.0.0/16", + "159.65.0.0/16", + "167.99.0.0/16", + "178.128.0.0/16", + "206.189.0.0/16", + "207.154.0.0/16", + "209.97.0.0/16", + // 负载均衡器 IP 范围 + "144.126.0.0/16", + "143.198.0.0/16", + "161.35.0.0/16", + ]; + + let networks: Result, _> = ranges.into_iter().map(|s| IpNetwork::from_str(s)).collect(); + + match networks { + Ok(networks) => { + info!("Loaded {} DigitalOcean IP ranges", networks.len()); + Ok(networks) + } + Err(e) => Err(AppError::cloud(format!("Failed to parse DigitalOcean IP ranges: {}", e))), + } } } -/// GCP metadata fetcher -pub struct GcpMetadataFetcher { - client: Client, - metadata_endpoint: String, -} +/// Google Cloud IP 范围 +pub struct GoogleCloudIpRanges; -impl Default for GcpMetadataFetcher { - fn default() -> Self { - Self::new() - } -} - -impl GcpMetadataFetcher { - pub fn new() -> Self { +impl GoogleCloudIpRanges { + /// 从 Google API 获取 IP 范围 + pub async fn fetch() -> Result, AppError> { let client = Client::builder() - .timeout(Duration::from_secs(2)) + .timeout(Duration::from_secs(10)) .build() - .unwrap_or_else(|_| Client::new()); + .map_err(|e| AppError::cloud(format!("Failed to create HTTP client: {}", e)))?; - Self { - client, - metadata_endpoint: "http://metadata.google.internal".to_string(), - } - } - - /// Get GCP metadata - async fn get_metadata(&self, path: &str) -> Result { - let url = format!("{}/computeMetadata/v1/{}", self.metadata_endpoint, path); - - let response = self.client.get(&url).header("Metadata-Flavor", "Google").send().await?; - - if response.status().is_success() { - let text = response.text().await?; - Ok(text) - } else { - Err(CloudMetadataError::MetadataUnavailable(format!( - "GCP metadata path {} returns status:{}", - path, - response.status() - ))) - } - } - - /// Get GCP public IP ranges - async fn get_gcp_public_ip_ranges(&self) -> Result, CloudMetadataError> { let url = "https://www.gstatic.com/ipranges/cloud.json"; - #[derive(Debug, Deserialize)] - struct GcpIpRanges { - prefixes: Vec, + #[derive(Debug, serde::Deserialize)] + struct GoogleIpRanges { + prefixes: Vec, } - #[derive(Debug, Deserialize)] - struct GcpPrefix { + #[derive(Debug, serde::Deserialize)] + struct GooglePrefix { ipv4_prefix: Option, ipv6_prefix: Option, } - let response = self.client.get(url).timeout(Duration::from_secs(5)).send().await?; + match client.get(url).send().await { + Ok(response) => { + if response.status().is_success() { + let ip_ranges: GoogleIpRanges = response + .json() + .await + .map_err(|e| AppError::cloud(format!("Failed to parse Google IP ranges: {}", e)))?; - if response.status().is_success() { - let ip_ranges: GcpIpRanges = response.json().await?; + let mut networks = Vec::new(); - let mut ranges = Vec::new(); - for prefix in ip_ranges.prefixes { - if let Some(ipv4_prefix) = prefix.ipv4_prefix - && let Ok(network) = ipv4_prefix.parse::() - { - ranges.push(network); - } - } + for prefix in ip_ranges.prefixes { + if let Some(ipv4_prefix) = prefix.ipv4_prefix { + if let Ok(network) = IpNetwork::from_str(&ipv4_prefix) { + networks.push(network); + } + } + } - info!("{} public IP ranges were obtained from GCP officials", ranges.len()); - Ok(ranges) - } else { - Err(CloudMetadataError::MetadataUnavailable(format!( - "GCP IP Range API returns status:{}", - response.status() - ))) - } - } -} - -#[async_trait] -impl CloudMetadataFetcher for GcpMetadataFetcher { - fn provider_name(&self) -> &str { - "gcp" - } - - async fn fetch_network_cidrs(&self) -> Result, CloudMetadataError> { - // Get network interface information - let metadata = self.get_metadata("instance/network-interfaces/").await?; - - let interface_indices: Vec = metadata - .lines() - .filter_map(|line| { - let line = line.trim().trim_end_matches('/'); - if line.chars().all(|c| c.is_ascii_digit()) { - line.parse().ok() + info!("Fetched {} Google Cloud IP ranges from API", networks.len()); + Ok(networks) } else { - None - } - }) - .collect(); - - let mut cidrs = Vec::new(); - - for index in interface_indices { - // Get the subnet range - let subnet_path = format!("instance/network-interfaces/{}/subnetworks", index); - match self.get_metadata(&subnet_path).await { - Ok(_subnet_metadata) => { - // Subnet metadata may contain CIDR information - // Simplified processing: we get IP addresses and netmasks - let ip_path = format!("instance/network-interfaces/{}/ip", index); - let mask_path = format!("instance/network-interfaces/{}/subnetmask", index); - - if let (Ok(ip), Ok(mask)) = tokio::join!(self.get_metadata(&ip_path), self.get_metadata(&mask_path)) - && let (Ok(ip_addr), Ok(mask_len)) = (ip.trim().parse::(), mask_to_prefix_length(&mask)) - && let Ok(network) = format!("{}/{}", ip_addr, mask_len).parse::() - { - cidrs.push(network); - } - } - Err(e) => { - debug!("Failed to get GCP subnet information: {}", e); + debug!("Failed to fetch Google IP ranges: {}", response.status()); + Ok(Vec::new()) } } - } - - if cidrs.is_empty() { - Err(CloudMetadataError::MetadataUnavailable("GCP network CIDR is not available".to_string())) - } else { - info!("{} network CIDRs were obtained from GCP metadata", cidrs.len()); - Ok(cidrs) - } - } - - async fn fetch_public_ip_ranges(&self) -> Result, CloudMetadataError> { - self.get_gcp_public_ip_ranges().await - } -} - -/// Convert the subnet mask to the prefix length -pub fn mask_to_prefix_length(mask: &str) -> Result { - let mask_parts: Vec<&str> = mask.split('.').collect(); - if mask_parts.len() != 4 { - return Err(CloudMetadataError::IpParseError(format!("Invalid subnet masks:{}", mask))); - } - - let mut prefix_length = 0; - for part in mask_parts { - let octet: u8 = part - .parse() - .map_err(|_| CloudMetadataError::IpParseError(format!("Invalid mask octet:{}", part)))?; - - let mut remaining = octet; - while remaining > 0 { - if remaining & 0x80 == 0x80 { - prefix_length += 1; - remaining <<= 1; - } else { - break; + Err(e) => { + debug!("Failed to fetch Google IP ranges: {}", e); + Ok(Vec::new()) } } - - if remaining != 0 { - return Err(CloudMetadataError::IpParseError("Non-contiguous subnet masks".to_string())); - } - } - - Ok(prefix_length) -} - -/// Universal Cloud Metadata Fetcher (Auto-Detect) -pub struct CloudMetadataDetector { - client: Client, - provider: Option, -} - -impl Default for CloudMetadataDetector { - fn default() -> Self { - Self::new() - } -} - -impl CloudMetadataDetector { - pub fn new() -> Self { - let client = Client::builder() - .timeout(Duration::from_secs(3)) - .build() - .unwrap_or_else(|_| Client::new()); - - let provider = CloudProvider::detect_from_env(); - - if let Some(p) = &provider { - info!("Cloud service provider detected:{}", p.name()); - } else { - info!("The cloud service provider is not detected, and it may be running on-premises or in an unknown environment"); - } - - Self { client, provider } - } - - /// Create a fetcher for a specific cloud service provider - pub fn create_fetcher(&self) -> Option> { - match self.provider { - Some(CloudProvider::Aws) => Some(Box::new(AwsMetadataFetcher::new())), - Some(CloudProvider::Azure) => Some(Box::new(AzureMetadataFetcher::new())), - Some(CloudProvider::Gcp) => Some(Box::new(GcpMetadataFetcher::new())), - Some(CloudProvider::DigitalOcean) => { - // DigitalOcean has a similar implementation - None - } - _ => None, - } - } - - /// Try metadata endpoints for all cloud providers - pub async fn try_all_providers(&self) -> Result, CloudMetadataError> { - let providers: Vec> = vec![ - Box::new(AwsMetadataFetcher::new()), - Box::new(AzureMetadataFetcher::new()), - Box::new(GcpMetadataFetcher::new()), - ]; - - for provider in providers { - let provider_name = provider.provider_name(); - debug!("Try getting metadata from {}", provider_name); - - match provider.fetch_trusted_proxy_ranges().await { - Ok(ranges) => { - if !ranges.is_empty() { - info!("{} IP ranges are obtained from {}", provider_name, ranges.len()); - return Ok(ranges); - } - } - Err(e) => { - debug!("Failed to get metadata from {}: {}", provider_name, e); - } - } - } - - Err(CloudMetadataError::MetadataUnavailable( - "All cloud service provider metadata fetching fails".to_string(), - )) - } -} - -/// Main export function - Get the IP range from the cloud service provider -pub async fn fetch_cloud_provider_ips() -> Result, CloudMetadataError> { - let detector = CloudMetadataDetector::new(); - - let ip_ranges = if let Some(fetcher) = detector.create_fetcher() { - // Use a detected cloud service provider - fetcher.fetch_trusted_proxy_ranges().await? - } else { - // Try all cloud service providers - detector.try_all_providers().await? - }; - - // Convert to a list of strings - let result: Vec = ip_ranges.into_iter().map(|network| network.to_string()).collect(); - - Ok(result) -} - -/// Asynchronously Obtaining CSP IP Ranges (with Timeout) -pub async fn fetch_cloud_provider_ips_with_timeout(timeout_secs: u64) -> Result, CloudMetadataError> { - tokio::time::timeout(Duration::from_secs(timeout_secs), fetch_cloud_provider_ips()) - .await - .map_err(|_| CloudMetadataError::MetadataUnavailable("Metadata fetch timeout".to_string()))? -} - -/// Synchronous version (used in a synchronous context) -pub fn fetch_cloud_provider_ips_sync(timeout_secs: u64) -> Result, CloudMetadataError> { - let runtime = tokio::runtime::Runtime::new() - .map_err(|e| CloudMetadataError::ConfigurationError(format!("Unable to create runtime: {}", e)))?; - - runtime.block_on(fetch_cloud_provider_ips_with_timeout(timeout_secs)) -} - -#[cfg(test)] -mod tests { - #[test] - fn test_mask_to_prefix_length() { - use crate::cloud::mask_to_prefix_length; - - assert_eq!(mask_to_prefix_length("255.255.255.0").unwrap(), 24); - assert_eq!(mask_to_prefix_length("255.255.0.0").unwrap(), 16); - assert_eq!(mask_to_prefix_length("255.0.0.0").unwrap(), 8); - assert_eq!(mask_to_prefix_length("255.255.255.252").unwrap(), 30); - - // Invalid masks should fail - assert!(mask_to_prefix_length("255.255.255.1").is_err()); - assert!(mask_to_prefix_length("invalid").is_err()); - } - - #[tokio::test] - async fn test_cloud_metadata_fallback() { - use crate::cloud::ranges::CloudMetadataDetector; - - // In a test environment, the metadata service should not be available - let detector = CloudMetadataDetector::new(); - - // Trying all providers should fail (unless running tests in a real cloud environment) - let result = detector.try_all_providers().await; - assert!(result.is_err()); - } - - #[test] - fn test_cloud_ip_parsing() { - use ipnetwork::IpNetwork; - - // Test IP range resolution - let cidr: IpNetwork = "10.0.0.0/8".parse().unwrap(); - assert_eq!(cidr.prefix(), 8); - - let cidr: IpNetwork = "192.168.1.0/24".parse().unwrap(); - assert_eq!(cidr.prefix(), 24); - - // IPv6 - let cidr: IpNetwork = "2001:db8::/32".parse().unwrap(); - assert_eq!(cidr.prefix(), 32); } } diff --git a/crates/trusted-proxies/src/config/env.rs b/crates/trusted-proxies/src/config/env.rs index c31c842d..b41a5afd 100644 --- a/crates/trusted-proxies/src/config/env.rs +++ b/crates/trusted-proxies/src/config/env.rs @@ -12,282 +12,178 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Environment variable definitions for trusted agent configurations -//! -//! All configuration items are read by environment variables and support the following priorities: -//! 1. Environment Variables (Highest Priority) -//! 2. The default value set in the code -//! 3. Hard-coded values in the default implementation of the struct +//! Environment variable configuration constants and helpers -use crate::TrustedProxy; -use crate::cloud::fetch_cloud_provider_ips_sync; +use crate::error::ConfigError; use ipnetwork::IpNetwork; use std::str::FromStr; -use tracing::{debug, info, warn}; -// Environment variable key constant definition -// Format: RUSTFS_HTTP_{SECTION}_{KEY}, all caps, separated by underscores -// ==================== Agent configuration ==================== -/// Agent verification mode -pub const ENV_PROXY_VALIDATION_MODE: &str = "RUSTFS_HTTP_PROXY_VALIDATION_MODE"; +// ==================== 代理基础配置 ==================== +/// 代理验证模式 +pub const ENV_PROXY_VALIDATION_MODE: &str = "TRUSTED_PROXY_VALIDATION_MODE"; pub const DEFAULT_PROXY_VALIDATION_MODE: &str = "hop_by_hop"; -/// Whether to enable RFC 7239 Forwarded headers -pub const ENV_PROXY_ENABLE_RFC7239: &str = "RUSTFS_HTTP_PROXY_ENABLE_RFC7239"; +/// 是否启用 RFC 7239 Forwarded 头部 +pub const ENV_PROXY_ENABLE_RFC7239: &str = "TRUSTED_PROXY_ENABLE_RFC7239"; pub const DEFAULT_PROXY_ENABLE_RFC7239: bool = true; -/// Maximum number of proxy hops -pub const ENV_PROXY_MAX_PROXY_HOPS: &str = "RUSTFS_HTTP_PROXY_MAX_PROXY_HOPS"; -pub const DEFAULT_PROXY_MAX_PROXY_HOPS: usize = 10; +/// 最大代理跳数 +pub const ENV_PROXY_MAX_HOPS: &str = "TRUSTED_PROXY_MAX_HOPS"; +pub const DEFAULT_PROXY_MAX_HOPS: usize = 10; -/// whether chain continuity checking is enabled -pub const ENV_PROXY_ENABLE_CHAIN_CONTINUITY_CHECK: &str = "RUSTFS_HTTP_PROXY_ENABLE_CHAIN_CONTINUITY_CHECK"; -pub const DEFAULT_PROXY_ENABLE_CHAIN_CONTINUITY_CHECK: bool = true; +/// 是否启用链连续性检查 +pub const ENV_PROXY_CHAIN_CONTINUITY_CHECK: &str = "TRUSTED_PROXY_CHAIN_CONTINUITY_CHECK"; +pub const DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK: bool = true; -// ==================== Trusted agent configuration ==================== -/// Underlying Trusted Agent List (Comma-Separated IP/CIDR) -pub const ENV_TRUSTED_PROXIES: &str = "RUSTFS_HTTP_TRUSTED_PROXIES"; +/// 是否记录验证失败的请求 +pub const ENV_PROXY_LOG_FAILED_VALIDATIONS: &str = "TRUSTED_PROXY_LOG_FAILED_VALIDATIONS"; +pub const DEFAULT_PROXY_LOG_FAILED_VALIDATIONS: bool = true; + +// ==================== 可信代理配置 ==================== +/// 基础可信代理列表(逗号分隔的 IP/CIDR) +pub const ENV_TRUSTED_PROXIES: &str = "TRUSTED_PROXY_NETWORKS"; pub const DEFAULT_TRUSTED_PROXIES: &str = "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"; -/// Additional Trusted Agent List (production only, can be overridden) -pub const ENV_ADDITIONAL_TRUSTED_PROXIES: &str = "RUSTFS_HTTP_ADDITIONAL_TRUSTED_PROXIES"; -pub const DEFAULT_ADDITIONAL_TRUSTED_PROXIES: &str = ""; +/// 额外可信代理列表(生产环境专用,可覆盖) +pub const ENV_EXTRA_TRUSTED_PROXIES: &str = "TRUSTED_PROXY_EXTRA_NETWORKS"; +pub const DEFAULT_EXTRA_TRUSTED_PROXIES: &str = ""; -/// Allowed private networks (for internal proxy authentication) -pub const ENV_ALLOWED_PRIVATE_NETS: &str = "RUSTFS_HTTP_ALLOWED_PRIVATE_NETS"; -pub const DEFAULT_ALLOWED_PRIVATE_NETS: &str = "10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"; +/// 私有网络范围(用于内部代理验证) +pub const ENV_PRIVATE_NETWORKS: &str = "TRUSTED_PROXY_PRIVATE_NETWORKS"; +pub const DEFAULT_PRIVATE_NETWORKS: &str = "10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"; -// ==================== Cache configuration ==================== -/// Cache capacity -pub const ENV_CACHE_CAPACITY: &str = "RUSTFS_HTTP_CACHE_CAPACITY"; -pub const DEFAULT_CACHE_CAPACITY: usize = 10000; +// ==================== 缓存配置 ==================== +/// 缓存容量 +pub const ENV_CACHE_CAPACITY: &str = "TRUSTED_PROXY_CACHE_CAPACITY"; +pub const DEFAULT_CACHE_CAPACITY: usize = 10_000; -/// Cache TTL (seconds) -pub const ENV_CACHE_TTL_SECONDS: &str = "RUSTFS_HTTP_CACHE_TTL_SECONDS"; +/// 缓存 TTL(秒) +pub const ENV_CACHE_TTL_SECONDS: &str = "TRUSTED_PROXY_CACHE_TTL_SECONDS"; pub const DEFAULT_CACHE_TTL_SECONDS: u64 = 300; -/// Cache Cleanup Interval (Seconds) -pub const ENV_CACHE_CLEANUP_INTERVAL_SECONDS: &str = "RUSTFS_HTTP_CACHE_CLEANUP_INTERVAL_SECONDS"; -pub const DEFAULT_CACHE_CLEANUP_INTERVAL_SECONDS: u64 = 60; +/// 缓存清理间隔(秒) +pub const ENV_CACHE_CLEANUP_INTERVAL: &str = "TRUSTED_PROXY_CACHE_CLEANUP_INTERVAL"; +pub const DEFAULT_CACHE_CLEANUP_INTERVAL: u64 = 60; -// ==================== Monitor configuration ==================== -/// Whether monitoring metrics are enabled -pub const ENV_MONITORING_ENABLE_METRICS: &str = "RUSTFS_HTTP_MONITORING_ENABLE_METRICS"; -pub const DEFAULT_MONITORING_ENABLE_METRICS: bool = true; +// ==================== 监控配置 ==================== +/// 是否启用监控指标 +pub const ENV_METRICS_ENABLED: &str = "TRUSTED_PROXY_METRICS_ENABLED"; +pub const DEFAULT_METRICS_ENABLED: bool = true; -/// Log level -pub const ENV_MONITORING_LOG_LEVEL: &str = "RUSTFS_HTTP_MONITORING_LOG_LEVEL"; -pub const DEFAULT_MONITORING_LOG_LEVEL: &str = "info"; +/// 日志级别 +pub const ENV_LOG_LEVEL: &str = "TRUSTED_PROXY_LOG_LEVEL"; +pub const DEFAULT_LOG_LEVEL: &str = "info"; -/// Whether to log validation failures -pub const ENV_MONITORING_LOG_FAILED_VALIDATIONS: &str = "RUSTFS_HTTP_MONITORING_LOG_FAILED_VALIDATIONS"; -pub const DEFAULT_MONITORING_LOG_FAILED_VALIDATIONS: bool = true; +/// 是否启用结构化日志 +pub const ENV_STRUCTURED_LOGGING: &str = "TRUSTED_PROXY_STRUCTURED_LOGGING"; +pub const DEFAULT_STRUCTURED_LOGGING: bool = false; -/// Cloud service provider-specific IP ranges (Cloudflare, etc.) -pub const ENV_CLOUDFLARE_IPS_ENABLED: &str = "RUSTFS_HTTP_CLOUDFLARE_IPS_ENABLED"; +/// 是否启用请求追踪 +pub const ENV_TRACING_ENABLED: &str = "TRUSTED_PROXY_TRACING_ENABLED"; +pub const DEFAULT_TRACING_ENABLED: bool = true; + +// ==================== 云服务集成 ==================== +/// 是否启用云元数据获取 +pub const ENV_CLOUD_METADATA_ENABLED: &str = "TRUSTED_PROXY_CLOUD_METADATA_ENABLED"; +pub const DEFAULT_CLOUD_METADATA_ENABLED: bool = false; + +/// 云元数据获取超时(秒) +pub const ENV_CLOUD_METADATA_TIMEOUT: &str = "TRUSTED_PROXY_CLOUD_METADATA_TIMEOUT"; +pub const DEFAULT_CLOUD_METADATA_TIMEOUT: u64 = 5; + +/// 是否启用 Cloudflare IP 范围 +pub const ENV_CLOUDFLARE_IPS_ENABLED: &str = "TRUSTED_PROXY_CLOUDFLARE_IPS_ENABLED"; pub const DEFAULT_CLOUDFLARE_IPS_ENABLED: bool = false; -/// Cloud metadata configuration -pub const ENV_CLOUD_METADATA_ENABLED: &str = "RUSTFS_HTTP_CLOUD_METADATA_ENABLED"; -pub const DEFAULT_CLOUD_METADATA_ENABLED: bool = true; +/// 强制指定的云服务商(覆盖自动检测) +pub const ENV_CLOUD_PROVIDER_FORCE: &str = "TRUSTED_PROXY_CLOUD_PROVIDER_FORCE"; +pub const DEFAULT_CLOUD_PROVIDER_FORCE: &str = ""; -pub const ENV_CLOUD_METADATA_TIMEOUT_SECS: &str = "RUSTFS_HTTP_CLOUD_METADATA_TIMEOUT_SECS"; -pub const DEFAULT_CLOUD_METADATA_TIMEOUT_SECS: u64 = 5; +// ==================== 辅助函数 ==================== -pub const ENV_CLOUD_PROVIDER_FORCE: &str = "RUSTFS_HTTP_CLOUD_PROVIDER_FORCE"; -pub const DEFAULT_CLOUD_PROVIDER_FORCE: &str = ""; // Null means automatic detection +/// 从环境变量解析逗号分隔的IP/CIDR列表 +pub fn parse_ip_list_from_env(key: &str, default: &str) -> Result, ConfigError> { + let value = std::env::var(key).unwrap_or_else(|_| default.to_string()); -/// Environment variables resolve error types -#[derive(Debug, thiserror::Error)] -pub enum ConfigError { - #[error("Environment variable resolution failed: {0}")] - EnvParseError(String), - - #[error("IP/CIDR Format Error: {0}")] - IpFormatError(String), - - #[error("Boolean parsing failed: {0}")] - BoolParseError(String), - - #[error("Numeric parsing failed: {0}")] - NumberParseError(String), - - #[error("Enum value parsing failed: {0}")] - EnumParseError(String), -} - -/// Environment variables configure loaders -pub struct EnvConfigLoader; - -impl EnvConfigLoader { - /// Get the string value from the environment variable - pub fn get_string(key: &str, default: &str) -> String { - std::env::var(key).unwrap_or_else(|_| default.to_string()) + if value.trim().is_empty() { + return Ok(Vec::new()); } - /// Get Boolean values from environment variables - pub fn get_bool(key: &str, default: bool) -> Result { - let value = Self::get_string(key, if default { "true" } else { "false" }); - value - .parse() - .map_err(|_| ConfigError::BoolParseError(format!("{}={}", key, value))) - } - - /// Get an integer value from an environment variable - pub fn get_usize(key: &str, default: usize) -> Result { - let value = Self::get_string(key, &default.to_string()); - value - .parse() - .map_err(|e| ConfigError::NumberParseError(format!("{}={}: {}", key, value, e))) - } - - /// Get the u64 value from the environment variable - pub fn get_u64(key: &str, default: u64) -> Result { - let value = Self::get_string(key, &default.to_string()); - value - .parse() - .map_err(|e| ConfigError::NumberParseError(format!("{}={}: {}", key, value, e))) - } - - /// Parsing comma-separated IP/CIDR lists from environment variables - pub fn parse_ip_list(key: &str, default: &str) -> Result, ConfigError> { - let value = Self::get_string(key, default); - if value.trim().is_empty() { - return Ok(Vec::new()); + let mut networks = Vec::new(); + for item in value.split(',') { + let item = item.trim(); + if item.is_empty() { + continue; } - let mut proxies = Vec::new(); - for item in value.split(',') { - let item = item.trim(); - if item.is_empty() { - continue; - } - - // Attempt to resolve to CIDR - if item.contains('/') { - match IpNetwork::from_str(item) { - Ok(network) => proxies.push(TrustedProxy::Cidr(network)), - Err(e) => return Err(ConfigError::IpFormatError(format!("{}: {}: {}", key, item, e))), - } - } else { - // Attempt to resolve to a single IP - match item.parse() { - Ok(ip) => proxies.push(TrustedProxy::Single(ip)), - Err(e) => return Err(ConfigError::IpFormatError(format!("{}: {}: {}", key, item, e))), - } - } - } - - Ok(proxies) - } - - /// Parses comma-separated CIDR lists from environment variables - pub fn parse_cidr_list(key: &str, default: &str) -> Result, ConfigError> { - let value = Self::get_string(key, default); - if value.trim().is_empty() { - return Ok(Vec::new()); - } - - let mut networks = Vec::new(); - for item in value.split(',') { - let item = item.trim(); - if item.is_empty() { - continue; - } - - match IpNetwork::from_str(item) { - Ok(network) => networks.push(network), - Err(e) => return Err(ConfigError::IpFormatError(format!("{}: {}: {}", key, item, e))), - } - } - - Ok(networks) - } - - /// Get the validation schema enumeration value - pub fn get_validation_mode(key: &str, default: &str) -> Result { - let value = Self::get_string(key, default); - match value.to_lowercase().as_str() { - "lenient" => Ok(crate::advanced::ValidationMode::Lenient), - "strict" => Ok(crate::advanced::ValidationMode::Strict), - "hop_by_hop" => Ok(crate::advanced::ValidationMode::HopByHop), - _ => Err(ConfigError::EnumParseError(format!( - "{}: Must be 'lenient', 'strict' or 'hop_by_hop'", - value - ))), - } - } - - /// Get the log level - pub fn get_log_level(key: &str, default: &str) -> String { - let value = Self::get_string(key, default).to_lowercase(); - match value.as_str() { - "trace" | "debug" | "info" | "warn" | "error" => value, - _ => default.to_string(), - } - } - - /// Get the cloud metadata IP range - pub fn fetch_cloud_metadata_ips() -> Result, ConfigError> { - // Check if cloud metadata is enabled - let enabled = Self::get_bool(ENV_CLOUD_METADATA_ENABLED, DEFAULT_CLOUD_METADATA_ENABLED)?; - - if !enabled { - debug!("Cloud metadata fetching is disabled"); - return Ok(Vec::new()); - } - - let timeout_secs = Self::get_u64(ENV_CLOUD_METADATA_TIMEOUT_SECS, DEFAULT_CLOUD_METADATA_TIMEOUT_SECS)?; - - // If there is a mandatory designation of a cloud service provider, set the environment variable - let forced_provider = Self::get_string(ENV_CLOUD_PROVIDER_FORCE, DEFAULT_CLOUD_PROVIDER_FORCE); - if !forced_provider.is_empty() { - // std::env::set_var("CLOUD_PROVIDER_FORCE", forced_provider); - } - - match fetch_cloud_provider_ips_sync(timeout_secs) { - Ok(ips) => { - info!("{} IP ranges were obtained from cloud metadata", ips.len()); - if !ips.is_empty() { - debug!("Cloud IP range: {:?}", ips); - } - Ok(ips) - } + match IpNetwork::from_str(item) { + Ok(network) => networks.push(network), Err(e) => { - warn!("Cloud metadata fetch failed: {}", e); - Ok(Vec::new()) + tracing::warn!("Failed to parse network '{}' from {}: {}", item, key, e); } } } + + Ok(networks) } -/// Cloud service provider IP range configuration -pub struct CloudProviderIps { - /// Cloudflare IP range - pub cloudflare: Vec, +/// 从环境变量解析逗号分隔的字符串列表 +pub fn parse_string_list_from_env(key: &str, default: &str) -> Vec { + let value = std::env::var(key).unwrap_or_else(|_| default.to_string()); + + value + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() } -impl CloudProviderIps { - /// Get the Cloudflare IP range - pub fn cloudflare_ranges() -> Vec { - let ranges = vec![ - "103.21.244.0/22", - "103.22.200.0/22", - "103.31.4.0/22", - "104.16.0.0/13", - "104.24.0.0/14", - "108.162.192.0/18", - "131.0.72.0/22", - "141.101.64.0/18", - "162.158.0.0/15", - "172.64.0.0/13", - "173.245.48.0/20", - "188.114.96.0/20", - "190.93.240.0/20", - "197.234.240.0/22", - "198.41.128.0/17", - ]; - - ranges.into_iter().filter_map(|s| IpNetwork::from_str(s).ok()).collect() - } +/// 从环境变量获取布尔值 +pub fn get_bool_from_env(key: &str, default: bool) -> bool { + std::env::var(key) + .map(|v| match v.to_lowercase().as_str() { + "true" | "1" | "yes" | "on" => true, + "false" | "0" | "no" | "off" => false, + _ => default, + }) + .unwrap_or(default) +} + +/// 从环境变量获取整数值 +pub fn get_usize_from_env(key: &str, default: usize) -> usize { + std::env::var(key).ok().and_then(|v| v.parse().ok()).unwrap_or(default) +} + +/// 从环境变量获取 u64 值 +pub fn get_u64_from_env(key: &str, default: u64) -> u64 { + std::env::var(key).ok().and_then(|v| v.parse().ok()).unwrap_or(default) +} + +/// 从环境变量获取字符串值 +pub fn get_string_from_env(key: &str, default: &str) -> String { + std::env::var(key).unwrap_or_else(|_| default.to_string()) +} + +/// 检查环境变量是否已设置 +pub fn is_env_set(key: &str) -> bool { + std::env::var(key).is_ok() +} + +/// 获取所有与可信代理相关的环境变量(用于调试) +pub fn get_all_proxy_env_vars() -> Vec<(String, String)> { + let vars = [ + ENV_PROXY_VALIDATION_MODE, + ENV_PROXY_ENABLE_RFC7239, + ENV_PROXY_MAX_HOPS, + ENV_PROXY_CHAIN_CONTINUITY_CHECK, + ENV_TRUSTED_PROXIES, + ENV_EXTRA_TRUSTED_PROXIES, + ENV_CLOUD_METADATA_ENABLED, + ENV_CLOUD_METADATA_TIMEOUT, + ENV_CLOUDFLARE_IPS_ENABLED, + ]; + + vars.iter() + .filter_map(|&key| std::env::var(key).ok().map(|value| (key.to_string(), value))) + .collect() } diff --git a/crates/trusted-proxies/src/config/loader.rs b/crates/trusted-proxies/src/config/loader.rs index 6238cfff..c86da95d 100644 --- a/crates/trusted-proxies/src/config/loader.rs +++ b/crates/trusted-proxies/src/config/loader.rs @@ -11,3 +11,194 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Configuration loader for environment variables and files + +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; + +use crate::config::env::*; +use crate::config::{AppConfig, CacheConfig, CloudConfig, MonitoringConfig, TrustedProxy, TrustedProxyConfig, ValidationMode}; +use crate::error::ConfigError; + +/// 配置加载器 +#[derive(Debug, Clone)] +pub struct ConfigLoader; + +impl ConfigLoader { + /// 从环境变量加载完整应用配置 + pub fn from_env() -> Result { + // 加载可信代理配置 + let proxy_config = Self::load_proxy_config()?; + + // 加载缓存配置 + let cache_config = Self::load_cache_config(); + + // 加载监控配置 + let monitoring_config = Self::load_monitoring_config(); + + // 加载云服务配置 + let cloud_config = Self::load_cloud_config(); + + // 服务器地址 + let server_addr = Self::load_server_addr(); + + Ok(AppConfig::new(proxy_config, cache_config, monitoring_config, cloud_config, server_addr)) + } + + /// 加载可信代理配置 + fn load_proxy_config() -> Result { + // 解析可信代理列表 + let mut proxies = Vec::new(); + + // 基础可信代理 + let base_networks = parse_ip_list_from_env(ENV_TRUSTED_PROXIES, DEFAULT_TRUSTED_PROXIES)?; + for network in base_networks { + proxies.push(TrustedProxy::Cidr(network)); + } + + // 额外可信代理 + let extra_networks = parse_ip_list_from_env(ENV_EXTRA_TRUSTED_PROXIES, DEFAULT_EXTRA_TRUSTED_PROXIES)?; + for network in extra_networks { + proxies.push(TrustedProxy::Cidr(network)); + } + + // 单个 IP(从环境变量解析) + let ip_strings = parse_string_list_from_env("TRUSTED_PROXY_IPS", ""); + for ip_str in ip_strings { + if let Ok(ip) = ip_str.parse::() { + proxies.push(TrustedProxy::Single(ip)); + } + } + + // 验证模式 + let validation_mode_str = get_string_from_env(ENV_PROXY_VALIDATION_MODE, DEFAULT_PROXY_VALIDATION_MODE); + let validation_mode = ValidationMode::from_str(&validation_mode_str)?; + + // 其他配置 + let enable_rfc7239 = get_bool_from_env(ENV_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_ENABLE_RFC7239); + let max_hops = get_usize_from_env(ENV_PROXY_MAX_HOPS, DEFAULT_PROXY_MAX_HOPS); + let enable_chain_check = get_bool_from_env(ENV_PROXY_CHAIN_CONTINUITY_CHECK, DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK); + + // 私有网络 + let private_networks = parse_ip_list_from_env(ENV_PRIVATE_NETWORKS, DEFAULT_PRIVATE_NETWORKS)?; + + Ok(TrustedProxyConfig::new( + proxies, + validation_mode, + enable_rfc7239, + max_hops, + enable_chain_check, + private_networks, + )) + } + + /// 加载缓存配置 + fn load_cache_config() -> CacheConfig { + CacheConfig { + capacity: get_usize_from_env(ENV_CACHE_CAPACITY, DEFAULT_CACHE_CAPACITY), + ttl_seconds: get_u64_from_env(ENV_CACHE_TTL_SECONDS, DEFAULT_CACHE_TTL_SECONDS), + cleanup_interval_seconds: get_u64_from_env(ENV_CACHE_CLEANUP_INTERVAL, DEFAULT_CACHE_CLEANUP_INTERVAL), + } + } + + /// 加载监控配置 + fn load_monitoring_config() -> MonitoringConfig { + MonitoringConfig { + metrics_enabled: get_bool_from_env(ENV_METRICS_ENABLED, DEFAULT_METRICS_ENABLED), + log_level: get_string_from_env(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL), + structured_logging: get_bool_from_env(ENV_STRUCTURED_LOGGING, DEFAULT_STRUCTURED_LOGGING), + tracing_enabled: get_bool_from_env(ENV_TRACING_ENABLED, DEFAULT_TRACING_ENABLED), + log_failed_validations: get_bool_from_env(ENV_PROXY_LOG_FAILED_VALIDATIONS, DEFAULT_PROXY_LOG_FAILED_VALIDATIONS), + } + } + + /// 加载云服务配置 + fn load_cloud_config() -> CloudConfig { + let forced_provider_str = get_string_from_env(ENV_CLOUD_PROVIDER_FORCE, DEFAULT_CLOUD_PROVIDER_FORCE); + let forced_provider = if forced_provider_str.is_empty() { + None + } else { + Some(forced_provider_str) + }; + + CloudConfig { + metadata_enabled: get_bool_from_env(ENV_CLOUD_METADATA_ENABLED, DEFAULT_CLOUD_METADATA_ENABLED), + metadata_timeout_seconds: get_u64_from_env(ENV_CLOUD_METADATA_TIMEOUT, DEFAULT_CLOUD_METADATA_TIMEOUT), + cloudflare_ips_enabled: get_bool_from_env(ENV_CLOUDFLARE_IPS_ENABLED, DEFAULT_CLOUDFLARE_IPS_ENABLED), + forced_provider, + } + } + + /// 加载服务器地址 + fn load_server_addr() -> SocketAddr { + let host = get_string_from_env("SERVER_HOST", "0.0.0.0"); + let port = get_usize_from_env("SERVER_PORT", 3000) as u16; + + format!("{}:{}", host, port) + .parse() + .unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 3000))) + } + + /// 从环境变量加载配置,如果失败则使用默认值 + pub fn from_env_or_default() -> AppConfig { + match Self::from_env() { + Ok(config) => { + tracing::info!("Configuration loaded successfully from environment variables"); + config + } + Err(e) => { + tracing::warn!("Failed to load configuration from environment: {}. Using defaults", e); + Self::default_config() + } + } + } + + /// 创建默认配置 + pub fn default_config() -> AppConfig { + // 默认可信代理配置 + let proxy_config = TrustedProxyConfig::new( + vec![ + TrustedProxy::Single("127.0.0.1".parse().unwrap()), + TrustedProxy::Single("::1".parse().unwrap()), + ], + ValidationMode::HopByHop, + true, + 10, + true, + vec![ + "10.0.0.0/8".parse().unwrap(), + "172.16.0.0/12".parse().unwrap(), + "192.168.0.0/16".parse().unwrap(), + ], + ); + + // 默认应用配置 + AppConfig::new( + proxy_config, + CacheConfig::default(), + MonitoringConfig::default(), + CloudConfig::default(), + "0.0.0.0:3000".parse().unwrap(), + ) + } + + /// 打印配置摘要 + pub fn print_summary(config: &AppConfig) { + tracing::info!("=== Application Configuration ==="); + tracing::info!("Server: {}", config.server_addr); + tracing::info!("Trusted Proxies: {}", config.proxy.proxies.len()); + tracing::info!("Validation Mode: {:?}", config.proxy.validation_mode); + tracing::info!("Cache Capacity: {}", config.cache.capacity); + tracing::info!("Metrics Enabled: {}", config.monitoring.metrics_enabled); + tracing::info!("Cloud Metadata: {}", config.cloud.metadata_enabled); + + if config.monitoring.log_failed_validations { + tracing::info!("Failed validations will be logged"); + } + + if !config.proxy.proxies.is_empty() { + tracing::debug!("Trusted networks: {:?}", config.proxy.get_network_strings()); + } + } +} diff --git a/crates/trusted-proxies/src/config/mod.rs b/crates/trusted-proxies/src/config/mod.rs index 52da4fef..d7e9722b 100644 --- a/crates/trusted-proxies/src/config/mod.rs +++ b/crates/trusted-proxies/src/config/mod.rs @@ -17,3 +17,8 @@ mod loader; mod types; pub use env::*; +// Re-export commonly used types +pub use ipnetwork::IpNetwork; +pub use loader::*; +pub use std::net::{IpAddr, SocketAddr}; +pub use types::*; diff --git a/crates/trusted-proxies/src/config/types.rs b/crates/trusted-proxies/src/config/types.rs index 6238cfff..72802380 100644 --- a/crates/trusted-proxies/src/config/types.rs +++ b/crates/trusted-proxies/src/config/types.rs @@ -11,3 +11,286 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Configuration type definitions + +use ipnetwork::IpNetwork; +use serde::{Deserialize, Serialize}; +use std::net::{IpAddr, SocketAddr}; +use std::time::Duration; + +use crate::error::ConfigError; + +/// 代理验证模式 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ValidationMode { + /// 宽松模式:只要最后一个代理可信,就接受整个链 + Lenient, + /// 严格模式:要求链中所有代理都可信 + Strict, + /// 跳数验证模式:从右向左找到第一个不可信代理 + HopByHop, +} + +impl ValidationMode { + /// 从字符串解析验证模式 + pub fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "lenient" => Ok(Self::Lenient), + "strict" => Ok(Self::Strict), + "hop_by_hop" => Ok(Self::HopByHop), + _ => Err(ConfigError::InvalidConfig(format!( + "Invalid validation mode: '{}'. Must be one of: lenient, strict, hop_by_hop", + s + ))), + } + } + + /// 转换为字符串 + pub fn as_str(&self) -> &'static str { + match self { + Self::Lenient => "lenient", + Self::Strict => "strict", + Self::HopByHop => "hop_by_hop", + } + } +} + +impl Default for ValidationMode { + fn default() -> Self { + Self::HopByHop + } +} + +/// 可信代理类型 +#[derive(Debug, Clone)] +pub enum TrustedProxy { + /// 单个 IP 地址 + Single(IpAddr), + /// IP 地址段 (CIDR 表示法) + Cidr(IpNetwork), +} + +impl TrustedProxy { + /// 检查 IP 是否匹配此代理配置 + pub fn contains(&self, ip: &IpAddr) -> bool { + match self { + Self::Single(proxy_ip) => ip == proxy_ip, + Self::Cidr(network) => network.contains(*ip), + } + } + + /// 转换为字符串表示 + pub fn to_string(&self) -> String { + match self { + Self::Single(ip) => ip.to_string(), + Self::Cidr(network) => network.to_string(), + } + } +} + +/// 可信代理配置 +#[derive(Debug, Clone)] +pub struct TrustedProxyConfig { + /// 代理列表 + pub proxies: Vec, + /// 验证模式 + pub validation_mode: ValidationMode, + /// 是否启用 RFC 7239 Forwarded 头部 + pub enable_rfc7239: bool, + /// 最大代理跳数 + pub max_hops: usize, + /// 是否启用链连续性检查 + pub enable_chain_continuity_check: bool, + /// 私有网络范围 + pub private_networks: Vec, +} + +impl TrustedProxyConfig { + /// 创建新配置 + pub fn new( + proxies: Vec, + validation_mode: ValidationMode, + enable_rfc7239: bool, + max_hops: usize, + enable_chain_continuity_check: bool, + private_networks: Vec, + ) -> Self { + Self { + proxies, + validation_mode, + enable_rfc7239, + max_hops, + enable_chain_continuity_check, + private_networks, + } + } + + /// 检查 SocketAddr 是否来自可信代理 + pub fn is_trusted(&self, addr: &SocketAddr) -> bool { + let ip = addr.ip(); + self.proxies.iter().any(|proxy| proxy.contains(&ip)) + } + + /// 检查 IP 是否在私有网络范围内 + pub fn is_private_network(&self, ip: &IpAddr) -> bool { + self.private_networks.iter().any(|network| network.contains(*ip)) + } + + /// 获取所有网络范围的字符串表示(用于调试) + pub fn get_network_strings(&self) -> Vec { + self.proxies.iter().map(|p| p.to_string()).collect() + } + + /// 获取配置摘要 + pub fn summary(&self) -> String { + format!( + "TrustedProxyConfig {{ proxies: {}, mode: {}, max_hops: {} }}", + self.proxies.len(), + self.validation_mode.as_str(), + self.max_hops + ) + } +} + +/// 缓存配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CacheConfig { + /// 缓存容量 + pub capacity: usize, + /// 缓存 TTL(秒) + pub ttl_seconds: u64, + /// 缓存清理间隔(秒) + pub cleanup_interval_seconds: u64, +} + +impl Default for CacheConfig { + fn default() -> Self { + Self { + capacity: 10_000, + ttl_seconds: 300, + cleanup_interval_seconds: 60, + } + } +} + +impl CacheConfig { + /// 获取缓存 TTL 时长 + pub fn ttl_duration(&self) -> Duration { + Duration::from_secs(self.ttl_seconds) + } + + /// 获取缓存清理间隔时长 + pub fn cleanup_interval(&self) -> Duration { + Duration::from_secs(self.cleanup_interval_seconds) + } +} + +/// 监控配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MonitoringConfig { + /// 是否启用监控指标 + pub metrics_enabled: bool, + /// 日志级别 + pub log_level: String, + /// 是否启用结构化日志 + pub structured_logging: bool, + /// 是否启用请求追踪 + pub tracing_enabled: bool, + /// 是否记录验证失败的请求 + pub log_failed_validations: bool, +} + +impl Default for MonitoringConfig { + fn default() -> Self { + Self { + metrics_enabled: true, + log_level: "info".to_string(), + structured_logging: false, + tracing_enabled: true, + log_failed_validations: true, + } + } +} + +/// 云服务集成配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CloudConfig { + /// 是否启用云元数据获取 + pub metadata_enabled: bool, + /// 云元数据获取超时(秒) + pub metadata_timeout_seconds: u64, + /// 是否启用 Cloudflare IP 范围 + pub cloudflare_ips_enabled: bool, + /// 强制指定的云服务商 + pub forced_provider: Option, +} + +impl Default for CloudConfig { + fn default() -> Self { + Self { + metadata_enabled: false, + metadata_timeout_seconds: 5, + cloudflare_ips_enabled: false, + forced_provider: None, + } + } +} + +impl CloudConfig { + /// 获取元数据获取超时时长 + pub fn metadata_timeout(&self) -> Duration { + Duration::from_secs(self.metadata_timeout_seconds) + } +} + +/// 完整的应用配置 +#[derive(Debug, Clone)] +pub struct AppConfig { + /// 代理配置 + pub proxy: TrustedProxyConfig, + /// 缓存配置 + pub cache: CacheConfig, + /// 监控配置 + pub monitoring: MonitoringConfig, + /// 云服务配置 + pub cloud: CloudConfig, + /// 服务器绑定地址 + pub server_addr: SocketAddr, +} + +impl AppConfig { + /// 创建应用配置 + pub fn new( + proxy: TrustedProxyConfig, + cache: CacheConfig, + monitoring: MonitoringConfig, + cloud: CloudConfig, + server_addr: SocketAddr, + ) -> Self { + Self { + proxy, + cache, + monitoring, + cloud, + server_addr, + } + } + + /// 获取配置摘要 + pub fn summary(&self) -> String { + format!( + "AppConfig {{\n\ + \x20\x20proxy: {},\n\ + \x20\x20cache_capacity: {},\n\ + \x20\x20metrics: {},\n\ + \x20\x20cloud_metadata: {}\n\ + }}", + self.proxy.summary(), + self.cache.capacity, + self.monitoring.metrics_enabled, + self.cloud.metadata_enabled + ) + } +} diff --git a/crates/trusted-proxies/src/error/config.rs b/crates/trusted-proxies/src/error/config.rs index 6238cfff..6c91aa02 100644 --- a/crates/trusted-proxies/src/error/config.rs +++ b/crates/trusted-proxies/src/error/config.rs @@ -11,3 +11,72 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Configuration error types + +use std::net::AddrParseError; + +/// 配置错误类型 +#[derive(Debug, thiserror::Error)] +pub enum ConfigError { + /// 环境变量缺失 + #[error("Missing environment variable: {0}")] + MissingEnvVar(String), + + /// 环境变量解析失败 + #[error("Failed to parse environment variable {0}: {1}")] + EnvParseError(String, String), + + /// 无效的配置值 + #[error("Invalid configuration value for {0}: {1}")] + InvalidValue(String, String), + + /// 无效的 IP 地址或网络 + #[error("Invalid IP address or network: {0}")] + InvalidIp(String), + + /// 配置验证失败 + #[error("Configuration validation failed: {0}")] + ValidationFailed(String), + + /// 配置冲突 + #[error("Configuration conflict: {0}")] + Conflict(String), + + /// 配置文件错误 + #[error("Config file error: {0}")] + FileError(String), + + /// 无效的配置 + #[error("Invalid config: {0}")] + InvalidConfig(String), +} + +impl From for ConfigError { + fn from(err: AddrParseError) -> Self { + Self::InvalidIp(err.to_string()) + } +} + +impl From for ConfigError { + fn from(err: ipnetwork::IpNetworkError) -> Self { + Self::InvalidIp(err.to_string()) + } +} + +impl ConfigError { + /// 创建环境变量缺失错误 + pub fn missing_env_var(key: &str) -> Self { + Self::MissingEnvVar(key.to_string()) + } + + /// 创建环境变量解析错误 + pub fn env_parse(key: &str, value: &str) -> Self { + Self::EnvParseError(key.to_string(), value.to_string()) + } + + /// 创建无效配置值错误 + pub fn invalid_value(field: &str, value: &str) -> Self { + Self::InvalidValue(field.to_string(), value.to_string()) + } +} diff --git a/crates/trusted-proxies/src/error/mod.rs b/crates/trusted-proxies/src/error/mod.rs index 33de8995..4b3c7502 100644 --- a/crates/trusted-proxies/src/error/mod.rs +++ b/crates/trusted-proxies/src/error/mod.rs @@ -12,5 +12,83 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Error types for the trusted proxy system + mod config; mod proxy; + +pub use config::*; +pub use proxy::*; + +/// 统一错误类型 +#[derive(Debug, thiserror::Error)] +pub enum AppError { + /// 配置错误 + #[error("Configuration error: {0}")] + Config(#[from] ConfigError), + + /// 代理验证错误 + #[error("Proxy validation error: {0}")] + Proxy(#[from] ProxyError), + + /// 云服务错误 + #[error("Cloud service error: {0}")] + Cloud(String), + + /// 内部错误 + #[error("Internal error: {0}")] + Internal(String), + + /// IO 错误 + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + /// HTTP 错误 + #[error("HTTP error: {0}")] + Http(String), +} + +impl AppError { + /// 创建云服务错误 + pub fn cloud(msg: impl Into) -> Self { + Self::Cloud(msg.into()) + } + + /// 创建内部错误 + pub fn internal(msg: impl Into) -> Self { + Self::Internal(msg.into()) + } + + /// 创建 HTTP 错误 + pub fn http(msg: impl Into) -> Self { + Self::Http(msg.into()) + } + + /// 判断错误是否可恢复 + pub fn is_recoverable(&self) -> bool { + match self { + Self::Config(_) => true, + Self::Proxy(_) => true, + Self::Cloud(_) => true, + Self::Internal(_) => false, + Self::Io(_) => true, + Self::Http(_) => true, + } + } +} + +/// HTTP 响应错误类型 +pub type ApiError = (axum::http::StatusCode, String); + +impl From for ApiError { + fn from(err: AppError) -> Self { + match err { + AppError::Config(_) => (axum::http::StatusCode::BAD_REQUEST, err.to_string()), + AppError::Proxy(_) => (axum::http::StatusCode::BAD_REQUEST, err.to_string()), + AppError::Cloud(_) => (axum::http::StatusCode::SERVICE_UNAVAILABLE, err.to_string()), + AppError::Internal(_) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()), + AppError::Io(_) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()), + AppError::Http(_) => (axum::http::StatusCode::BAD_GATEWAY, err.to_string()), + } + } +} diff --git a/crates/trusted-proxies/src/error/proxy.rs b/crates/trusted-proxies/src/error/proxy.rs index 6238cfff..9b01ce2e 100644 --- a/crates/trusted-proxies/src/error/proxy.rs +++ b/crates/trusted-proxies/src/error/proxy.rs @@ -11,3 +11,103 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Proxy validation error types + +use std::net::AddrParseError; + +/// 代理验证错误类型 +#[derive(Debug, thiserror::Error)] +pub enum ProxyError { + /// 无效的 X-Forwarded-For 头部 + #[error("Invalid X-Forwarded-For header: {0}")] + InvalidXForwardedFor(String), + + /// 无效的 Forwarded 头部(RFC 7239) + #[error("Invalid Forwarded header (RFC 7239): {0}")] + InvalidForwardedHeader(String), + + /// 代理链验证失败 + #[error("Proxy chain validation failed: {0}")] + ChainValidationFailed(String), + + /// 代理链过长 + #[error("Proxy chain too long: {0} hops (max: {1})")] + ChainTooLong(usize, usize), + + /// 来自不可信代理 + #[error("Request from untrusted proxy: {0}")] + UntrustedProxy(String), + + /// 代理链不连续 + #[error("Proxy chain is not continuous")] + ChainNotContinuous, + + /// IP 地址解析失败 + #[error("Failed to parse IP address: {0}")] + IpParseError(String), + + /// 头部解析失败 + #[error("Failed to parse header: {0}")] + HeaderParseError(String), + + /// 验证超时 + #[error("Validation timeout")] + Timeout, + + /// 内部验证错误 + #[error("Internal validation error: {0}")] + Internal(String), +} + +impl From for ProxyError { + fn from(err: AddrParseError) -> Self { + Self::IpParseError(err.to_string()) + } +} + +impl ProxyError { + /// 创建无效 X-Forwarded-For 头部错误 + pub fn invalid_xff(msg: impl Into) -> Self { + Self::InvalidXForwardedFor(msg.into()) + } + + /// 创建无效 Forwarded 头部错误 + pub fn invalid_forwarded(msg: impl Into) -> Self { + Self::InvalidForwardedHeader(msg.into()) + } + + /// 创建代理链验证失败错误 + pub fn chain_failed(msg: impl Into) -> Self { + Self::ChainValidationFailed(msg.into()) + } + + /// 创建来自不可信代理错误 + pub fn untrusted(proxy: impl Into) -> Self { + Self::UntrustedProxy(proxy.into()) + } + + /// 创建内部验证错误 + pub fn internal(msg: impl Into) -> Self { + Self::Internal(msg.into()) + } + + /// 判断错误是否可恢复(是否应该继续处理请求) + pub fn is_recoverable(&self) -> bool { + match self { + // 这些错误通常意味着我们应该拒绝请求或使用备用 IP + Self::UntrustedProxy(_) => true, + Self::ChainTooLong(_, _) => true, + Self::ChainNotContinuous => true, + + // 这些错误可能意味着配置问题或恶意请求 + Self::InvalidXForwardedFor(_) => false, + Self::InvalidForwardedHeader(_) => false, + Self::ChainValidationFailed(_) => false, + Self::IpParseError(_) => false, + Self::HeaderParseError(_) => false, + Self::Timeout => true, + Self::Internal(_) => false, + } + } +} diff --git a/crates/trusted-proxies/src/lib.rs b/crates/trusted-proxies/src/lib.rs index f8c14a2f..da8e80df 100644 --- a/crates/trusted-proxies/src/lib.rs +++ b/crates/trusted-proxies/src/lib.rs @@ -15,10 +15,11 @@ mod api; mod cloud; mod config; -mod errors; +mod error; mod logging; mod middleware; mod proxy; +mod state; mod utils; pub use cloud::*; diff --git a/crates/trusted-proxies/src/logging/middleware.rs b/crates/trusted-proxies/src/logging/middleware.rs index 6238cfff..805b75f1 100644 --- a/crates/trusted-proxies/src/logging/middleware.rs +++ b/crates/trusted-proxies/src/logging/middleware.rs @@ -11,3 +11,179 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Logging middleware for Axum + +use std::task::{Context, Poll}; +use std::time::Instant; +use tower::Service; +use uuid::Uuid; + +use crate::logging::Logger; + +/// 请求日志中间件层 +#[derive(Clone)] +pub struct RequestLoggingLayer { + logger: Logger, +} + +impl RequestLoggingLayer { + /// 创建新的日志中间件层 + pub fn new(logger: Logger) -> Self { + Self { logger } + } +} + +impl tower::Layer for RequestLoggingLayer { + type Service = RequestLoggingMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + RequestLoggingMiddleware { + inner, + logger: self.logger.clone(), + } + } +} + +/// 请求日志中间件服务 +#[derive(Clone)] +pub struct RequestLoggingMiddleware { + inner: S, + logger: Logger, +} + +impl Service for RequestLoggingMiddleware +where + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: axum::extract::Request) -> Self::Future { + let logger = self.logger.clone(); + let mut inner = self.inner.clone(); + + Box::pin(async move { + // 生成请求 ID + let request_id = Uuid::new_v4().to_string(); + + // 记录请求开始时间和日志 + let start_time = Instant::now(); + logger.log_request(&req, &request_id); + + // 将请求 ID 添加到请求扩展中 + let mut req = req; + req.extensions_mut().insert(RequestId(request_id.clone())); + + // 处理请求 + let result = inner.call(req).await; + + // 计算处理时间 + let duration = start_time.elapsed(); + + // 记录响应 + match &result { + Ok(response) => { + logger.log_response(response, &request_id, duration); + } + Err(error) => { + logger.log_error(error, Some(&request_id)); + } + } + + result + }) + } +} + +/// 请求 ID 包装器 +#[derive(Debug, Clone)] +pub struct RequestId(String); + +impl RequestId { + /// 获取请求 ID + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl std::fmt::Display for RequestId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// 代理特定的日志中间件 +#[derive(Clone)] +pub struct ProxyLoggingMiddleware { + inner: S, + logger: Logger, +} + +impl ProxyLoggingMiddleware { + /// 创建新的代理日志中间件 + pub fn new(inner: S, logger: Logger) -> Self { + Self { inner, logger } + } +} + +impl Service for ProxyLoggingMiddleware +where + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: axum::extract::Request) -> Self::Future { + // 记录代理相关信息 + let peer_addr = req.extensions().get::().copied(); + let client_info = req.extensions().get::(); + + if let (Some(addr), Some(info)) = (peer_addr, client_info) { + self.logger + .log_info(&format!("Proxy request from {}: {}", addr, info.to_log_string()), None); + + // 如果有警告,记录它们 + if !info.warnings.is_empty() { + for warning in &info.warnings { + self.logger.log_warning(warning, Some("proxy_validation")); + } + } + } + + self.inner.call(req) + } +} + +/// 代理日志中间件层 +#[derive(Clone)] +pub struct ProxyLoggingLayer { + logger: Logger, +} + +impl ProxyLoggingLayer { + /// 创建新的代理日志中间件层 + pub fn new(logger: Logger) -> Self { + Self { logger } + } +} + +impl tower::Layer for ProxyLoggingLayer { + type Service = ProxyLoggingMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + ProxyLoggingMiddleware::new(inner, self.logger.clone()) + } +} diff --git a/crates/trusted-proxies/src/logging/mod.rs b/crates/trusted-proxies/src/logging/mod.rs index 4838978f..92dd89bb 100644 --- a/crates/trusted-proxies/src/logging/mod.rs +++ b/crates/trusted-proxies/src/logging/mod.rs @@ -12,4 +12,208 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Logging module for structured logging and middleware + mod middleware; + +pub use middleware::*; + +/// 日志配置 +#[derive(Debug, Clone)] +pub struct LoggingConfig { + /// 是否启用结构化日志 + pub structured: bool, + /// 日志级别 + pub level: String, + /// 是否启用请求 ID + pub enable_request_id: bool, + /// 是否记录请求体 + pub log_request_body: bool, + /// 是否记录响应体 + pub log_response_body: bool, + /// 敏感字段列表(将被脱敏) + pub sensitive_fields: Vec, +} + +impl Default for LoggingConfig { + fn default() -> Self { + Self { + structured: false, + level: "info".to_string(), + enable_request_id: true, + log_request_body: false, + log_response_body: false, + sensitive_fields: vec![ + "password".to_string(), + "token".to_string(), + "secret".to_string(), + "authorization".to_string(), + ], + } + } +} + +/// 初始化日志系统 +pub fn init_logging(config: &LoggingConfig) -> Result<(), Box> { + // 创建日志过滤器 + let filter = tracing_subscriber::EnvFilter::builder() + .with_default_directive(config.level.parse().unwrap_or(tracing::Level::INFO.into())) + .from_env_lossy(); + + // 根据配置选择日志格式 + if config.structured { + // 结构化日志(JSON 格式) + tracing_subscriber::fmt() + .json() + .with_env_filter(filter) + .with_target(true) + .with_thread_ids(true) + .with_thread_names(true) + .with_file(true) + .with_line_number(true) + .init(); + } else { + // 普通文本日志 + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(true) + .with_thread_ids(true) + .with_thread_names(true) + .with_file(true) + .with_line_number(true) + .init(); + } + + tracing::info!("Logging initialized with level: {}", config.level); + + Ok(()) +} + +/// 日志记录器 +#[derive(Debug, Clone)] +pub struct Logger { + config: LoggingConfig, +} + +impl Logger { + /// 创建新的日志记录器 + pub fn new(config: LoggingConfig) -> Self { + Self { config } + } + + /// 记录 HTTP 请求 + pub fn log_request(&self, req: &axum::http::Request, request_id: &str) { + let method = req.method(); + let uri = req.uri(); + let version = req.version(); + + tracing::info!( + request.method = %method, + request.uri = %uri, + request.version = ?version, + request_id = %request_id, + "HTTP request received" + ); + + // 如果启用了请求体日志记录,记录头部 + if self.config.log_request_body { + self.log_headers(req.headers(), "request"); + } + } + + /// 记录 HTTP 响应 + pub fn log_response(&self, res: &axum::http::Response, request_id: &str, duration: std::time::Duration) { + let status = res.status(); + let version = res.version(); + + tracing::info!( + response.status = %status, + response.version = ?version, + request_id = %request_id, + duration_ms = duration.as_millis(), + "HTTP response sent" + ); + + // 如果启用了响应体日志记录,记录头部 + if self.config.log_response_body { + self.log_headers(res.headers(), "response"); + } + } + + /// 记录头部信息(脱敏敏感字段) + fn log_headers(&self, headers: &axum::http::HeaderMap, header_type: &str) { + let mut header_fields = std::collections::HashMap::new(); + + for (name, value) in headers { + let name_str = name.to_string(); + let value_str = match value.to_str() { + Ok(s) => s.to_string(), + Err(_) => "[BINARY]".to_string(), + }; + + // 检查是否为敏感字段 + let is_sensitive = self + .config + .sensitive_fields + .iter() + .any(|field| name_str.to_lowercase().contains(&field.to_lowercase())); + + if is_sensitive { + header_fields.insert(name_str, "[REDACTED]".to_string()); + } else { + header_fields.insert(name_str, value_str); + } + } + + tracing::debug!( + headers = ?header_fields, + header_type = header_type, + "HTTP headers" + ); + } + + /// 记录错误 + pub fn log_error(&self, error: &impl std::error::Error, request_id: Option<&str>) { + if let Some(id) = request_id { + tracing::error!( + error = %error, + error.type = std::any::type_name_of_val(error), + request_id = %id, + "Request error" + ); + } else { + tracing::error!( + error = %error, + error.type = std::any::type_name_of_val(error), + "Application error" + ); + } + } + + /// 记录警告 + pub fn log_warning(&self, message: &str, context: Option<&str>) { + if let Some(ctx) = context { + tracing::warn!(message = %message, context = %ctx, "Warning"); + } else { + tracing::warn!(message = %message, "Warning"); + } + } + + /// 记录信息 + pub fn log_info(&self, message: &str, context: Option<&str>) { + if let Some(ctx) = context { + tracing::info!(message = %message, context = %ctx, "Info"); + } else { + tracing::info!(message = %message, "Info"); + } + } + + /// 记录调试信息 + pub fn log_debug(&self, message: &str, context: Option<&str>) { + if let Some(ctx) = context { + tracing::debug!(message = %message, context = %ctx, "Debug"); + } else { + tracing::debug!(message = %message, "Debug"); + } + } +} diff --git a/crates/trusted-proxies/src/main.rs b/crates/trusted-proxies/src/main.rs index b4a0b3d4..47b8c574 100644 --- a/crates/trusted-proxies/src/main.rs +++ b/crates/trusted-proxies/src/main.rs @@ -12,4 +12,125 @@ // See the License for the specific language governing permissions and // limitations under the License. -fn main() {} +//! Main application entry point for the trusted proxy system + +use std::net::SocketAddr; +use std::sync::Arc; + +use axum::{ + extract::State, + response::{IntoResponse, Json}, + routing::get, + Router, +}; +use tokio::net::TcpListener; +use tracing::{error, info}; + +mod api; +mod cloud; +mod config; +mod error; +mod middleware; +mod proxy; +mod state; +mod utils; + +use api::handlers; +use config::{AppConfig, ConfigLoader}; +use error::AppError; +use middleware::TrustedProxyLayer; +use proxy::metrics::{default_proxy_metrics, ProxyMetrics}; +use state::AppState; + +#[tokio::main] +async fn main() -> Result<(), AppError> { + // 加载环境变量 + dotenvy::dotenv().ok(); + + // 从环境变量加载配置 + let config = ConfigLoader::from_env_or_default(); + + // 初始化日志 + init_logging(&config.monitoring)?; + + // 打印配置摘要 + ConfigLoader::print_summary(&config); + + // 初始化指标收集器 + let metrics = if config.monitoring.metrics_enabled { + let metrics = default_proxy_metrics(true); + metrics.print_summary(); + Some(metrics) + } else { + None + }; + + // 创建应用状态 + let state = AppState { + config: Arc::new(config), + metrics: metrics.clone(), + }; + + // 创建可信代理中间件层 + let proxy_layer = TrustedProxyLayer::enabled(state.clone().config.proxy.clone(), metrics); + + // 创建路由 + let app = Router::new() + // 健康检查端点 + .route("/health", get(handlers::health_check)) + // 配置查看端点 + .route("/config", get(handlers::show_config)) + // 客户端信息端点 + .route("/client-info", get(handlers::client_info)) + // 代理测试端点 + .route("/proxy-test", get(handlers::proxy_test)) + // 指标端点(如果启用) + .route("/metrics", get(handlers::metrics)) + // 添加应用状态 + .with_state(state.clone()) + // 添加可信代理中间件 + .layer(proxy_layer) + // 添加追踪中间件(如果启用) + .layer(tower_http::trace::TraceLayer::new_for_http()) + // 添加 CORS 中间件 + .layer(tower_http::cors::CorsLayer::permissive()) + // 添加压缩中间件 + .layer(tower_http::compression::CompressionLayer::new()); + + // 启动服务器 + let addr = state.config.server_addr; + let listener = TcpListener::bind(addr).await.map_err(|e| AppError::Io(e))?; + + info!("Server listening on http://{}", addr); + info!("Available endpoints:"); + info!(" GET /health - Health check"); + info!(" GET /config - Show configuration"); + info!(" GET /client-info - Show client information"); + info!(" GET /proxy-test - Test proxy headers"); + info!(" GET /metrics - Prometheus metrics (if enabled)"); + + axum::serve(listener, app).await.map_err(|e| AppError::Io(e))?; + + Ok(()) +} + +/// 初始化日志系统 +fn init_logging(monitoring_config: &config::MonitoringConfig) -> Result<(), AppError> { + // 创建日志过滤器 + let filter = tracing_subscriber::EnvFilter::builder() + .with_default_directive(monitoring_config.log_level.parse().unwrap_or(tracing::Level::INFO.into())) + .from_env_lossy(); + + // 根据配置选择日志格式 + if monitoring_config.structured_logging { + // 结构化日志(JSON 格式) + tracing_subscriber::fmt().json().with_env_filter(filter).init(); + } else { + // 普通文本日志 + tracing_subscriber::fmt().with_env_filter(filter).init(); + } + + info!("Logging initialized with level: {}", monitoring_config.log_level); + + Ok(()) +} diff --git a/crates/trusted-proxies/src/middleware/layer.rs b/crates/trusted-proxies/src/middleware/layer.rs index 6238cfff..6c74f11f 100644 --- a/crates/trusted-proxies/src/middleware/layer.rs +++ b/crates/trusted-proxies/src/middleware/layer.rs @@ -11,3 +11,65 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Tower layer implementation for trusted proxy middleware + +use std::sync::Arc; +use tower::Layer; + +use crate::config::TrustedProxyConfig; +use crate::middleware::TrustedProxyMiddleware; +use crate::proxy::ProxyMetrics; +use crate::proxy::ProxyValidator; + +/// 可信代理中间件层 +#[derive(Clone)] +pub struct TrustedProxyLayer { + /// 代理验证器 + pub(crate) validator: Arc, + /// 是否启用中间件 + pub(crate) enabled: bool, +} + +impl TrustedProxyLayer { + /// 创建新的中间件层 + pub fn new(config: TrustedProxyConfig, metrics: Option, enabled: bool) -> Self { + let validator = ProxyValidator::new(config, metrics); + + Self { + validator: Arc::new(validator), + enabled, + } + } + + /// 创建启用的中间件层 + pub fn enabled(config: TrustedProxyConfig, metrics: Option) -> Self { + Self::new(config, metrics, true) + } + + /// 创建禁用的中间件层 + pub fn disabled() -> Self { + Self::new( + TrustedProxyConfig::new(Vec::new(), crate::config::ValidationMode::Lenient, true, 10, true, Vec::new()), + None, + false, + ) + } + + /// 检查中间件是否启用 + pub fn is_enabled(&self) -> bool { + self.enabled + } +} + +impl Layer for TrustedProxyLayer { + type Service = TrustedProxyMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + TrustedProxyMiddleware { + inner, + validator: self.validator.clone(), + enabled: self.enabled, + } + } +} diff --git a/crates/trusted-proxies/src/middleware/mod.rs b/crates/trusted-proxies/src/middleware/mod.rs index 4d4cdab0..e9468529 100644 --- a/crates/trusted-proxies/src/middleware/mod.rs +++ b/crates/trusted-proxies/src/middleware/mod.rs @@ -12,5 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Middleware module for Axum web framework + mod layer; mod service; + +pub use layer::*; +pub use service::*; + +// Re-export commonly used types +pub use crate::proxy::ClientInfo; diff --git a/crates/trusted-proxies/src/middleware/service.rs b/crates/trusted-proxies/src/middleware/service.rs index 6238cfff..dbbe586a 100644 --- a/crates/trusted-proxies/src/middleware/service.rs +++ b/crates/trusted-proxies/src/middleware/service.rs @@ -11,3 +11,125 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Tower service implementation for trusted proxy middleware + +use std::sync::Arc; +use std::task::{ready, Context, Poll}; + +use axum::extract::Request; +use axum::response::Response; +use tower::Service; +use tracing::{debug, instrument, Span}; + +use crate::error::ProxyError; +use crate::middleware::layer::TrustedProxyLayer; +use crate::proxy::{ClientInfo, ProxyValidator}; + +/// 可信代理中间件服务 +#[derive(Clone)] +pub struct TrustedProxyMiddleware { + /// 内部服务 + inner: S, + /// 代理验证器 + validator: Arc, + /// 是否启用中间件 + enabled: bool, +} + +impl TrustedProxyMiddleware { + /// 创建新的中间件服务 + pub fn new(inner: S, validator: Arc, enabled: bool) -> Self { + Self { + inner, + validator, + enabled, + } + } + + /// 从层创建中间件服务 + pub fn from_layer(inner: S, layer: &TrustedProxyLayer) -> Self { + Self::new(inner, layer.validator.clone(), layer.enabled) + } +} + +impl Service for TrustedProxyMiddleware +where + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + #[instrument( + name = "trusted_proxy_middleware", + skip_all, + fields( + http.method = %req.method(), + http.uri = %req.uri(), + http.version = ?req.version(), + enabled = self.enabled, + ) + )] + fn call(&mut self, mut req: Request) -> Self::Future { + let span = Span::current(); + + // 如果中间件未启用,直接传递请求 + if !self.enabled { + debug!("Trusted proxy middleware is disabled"); + return self.inner.call(req); + } + + // 记录请求开始时间 + let start_time = std::time::Instant::now(); + + // 提取对端地址 + let peer_addr = req.extensions().get::().copied(); + + // 为 span 添加字段 + if let Some(addr) = peer_addr { + span.record("peer.addr", addr.to_string()); + } + + // 验证请求并提取客户端信息 + match self.validator.validate_request(peer_addr, req.headers()) { + Ok(client_info) => { + // 记录客户端信息到 span + span.record("client.ip", client_info.real_ip.to_string()); + span.record("client.trusted", client_info.is_from_trusted_proxy); + span.record("client.hops", client_info.proxy_hops as i64); + + // 将客户端信息存入请求扩展 + req.extensions_mut().insert(client_info); + + // 记录验证成功 + let duration = start_time.elapsed(); + debug!("Proxy validation successful in {:?}", duration); + } + Err(err) => { + // 记录验证失败 + span.record("error", true); + span.record("error.message", err.to_string()); + + // 如果是可恢复的错误,创建默认的客户端信息 + if err.is_recoverable() { + let client_info = ClientInfo::direct( + peer_addr.unwrap_or_else(|| std::net::SocketAddr::new(std::net::IpAddr::from([0, 0, 0, 0]), 0)), + ); + req.extensions_mut().insert(client_info); + } else { + // 对于不可恢复的错误,记录警告 + debug!("Unrecoverable proxy validation error: {}", err); + } + } + } + + // 调用内部服务 + self.inner.call(req) + } +} diff --git a/crates/trusted-proxies/src/proxy/cache.rs b/crates/trusted-proxies/src/proxy/cache.rs index 6238cfff..07ae6fc5 100644 --- a/crates/trusted-proxies/src/proxy/cache.rs +++ b/crates/trusted-proxies/src/proxy/cache.rs @@ -11,3 +11,227 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Cache implementation for proxy validation + +use metrics::{counter, gauge, histogram}; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +/// 缓存条目 +#[derive(Debug, Clone)] +struct CacheEntry { + /// 是否可信 + is_trusted: bool, + /// 缓存时间 + cached_at: Instant, + /// 过期时间 + expires_at: Instant, +} + +/// IP 验证缓存 +#[derive(Debug, Clone)] +pub struct IpValidationCache { + /// 缓存存储 + cache: Arc>>, + /// 最大容量 + capacity: usize, + /// 默认 TTL + default_ttl: Duration, + /// 是否启用 + enabled: bool, +} + +impl IpValidationCache { + /// 创建新的缓存 + pub fn new(capacity: usize, default_ttl: Duration, enabled: bool) -> Self { + Self { + cache: Arc::new(RwLock::new(HashMap::with_capacity(capacity))), + capacity, + default_ttl, + enabled, + } + } + + /// 检查 IP 是否可信(带缓存) + pub fn is_trusted(&self, ip: &IpAddr, validator: impl FnOnce(&IpAddr) -> bool) -> bool { + // 如果缓存未启用,直接验证 + if !self.enabled { + return validator(ip); + } + + let now = Instant::now(); + + // 检查缓存 + { + let cache = self.cache.read(); + if let Some(entry) = cache.get(ip) { + if now < entry.expires_at { + // 缓存命中 + counter!("proxy.cache.hits").increment(1); + return entry.is_trusted; + } + } + } + + // 缓存未命中 + counter!("proxy.cache.misses").increment(1); + + // 验证 IP + let is_trusted = validator(ip); + + // 更新缓存 + self.update_cache(*ip, is_trusted, now); + + is_trusted + } + + /// 更新缓存 + fn update_cache(&self, ip: IpAddr, is_trusted: bool, now: Instant) { + let mut cache = self.cache.write(); + + // 检查是否需要清理(如果达到容量限制) + if cache.len() >= self.capacity { + self.cleanup_expired(&mut cache, now); + + // 如果仍然满,删除最旧的条目 + if cache.len() >= self.capacity { + self.evict_oldest(&mut cache); + } + } + + // 添加新条目 + let entry = CacheEntry { + is_trusted, + cached_at: now, + expires_at: now + self.default_ttl, + }; + + cache.insert(ip, entry); + + // 更新指标 + gauge!("proxy.cache.size").set(cache.len() as f64); + } + + /// 清理过期条目 + fn cleanup_expired(&self, cache: &mut HashMap, now: Instant) { + let expired_keys: Vec<_> = cache + .iter() + .filter(|(_, entry)| now >= entry.expires_at) + .map(|(ip, _)| *ip) + .collect(); + + for key in expired_keys.clone() { + cache.remove(&key); + } + + if !expired_keys.is_empty() { + counter!("proxy.cache.evictions").increment(expired_keys.len() as u64); + } + } + + /// 淘汰最旧的条目 + fn evict_oldest(&self, cache: &mut HashMap) { + if let Some(oldest_key) = cache.iter().min_by_key(|(_, entry)| entry.cached_at).map(|(ip, _)| *ip) { + cache.remove(&oldest_key); + counter!("proxy.cache.evictions").increment(1); + } + } + + /// 清空缓存 + pub fn clear(&self) { + let mut cache = self.cache.write(); + cache.clear(); + gauge!("proxy.cache.size").set(0.00); + } + + /// 获取缓存统计信息 + pub fn stats(&self) -> CacheStats { + let cache = self.cache.read(); + + let mut oldest = Instant::now(); + let mut newest = Instant::now(); + let mut expired_count = 0; + + let now = Instant::now(); + for entry in cache.values() { + if entry.cached_at < oldest { + oldest = entry.cached_at; + } + if entry.cached_at > newest { + newest = entry.cached_at; + } + if now >= entry.expires_at { + expired_count += 1; + } + } + + CacheStats { + size: cache.len(), + capacity: self.capacity, + expired_count, + oldest_age: now.duration_since(oldest), + newest_age: now.duration_since(newest), + } + } + + /// 定期清理任务 + pub async fn cleanup_task(&self, interval: Duration) { + let mut interval_timer = tokio::time::interval(interval); + + loop { + interval_timer.tick().await; + self.cleanup(); + } + } + + /// 执行清理 + fn cleanup(&self) { + let now = Instant::now(); + let mut cache = self.cache.write(); + self.cleanup_expired(&mut cache, now); + + // 记录清理后的指标 + gauge!("proxy.cache.size").set(cache.len() as f64); + } +} + +/// 缓存统计信息 +#[derive(Debug, Clone)] +pub struct CacheStats { + /// 当前缓存大小 + pub size: usize, + /// 缓存容量 + pub capacity: usize, + /// 过期条目数量 + pub expired_count: usize, + /// 最旧条目的年龄 + pub oldest_age: Duration, + /// 最新条目的年龄 + pub newest_age: Duration, +} + +impl CacheStats { + /// 获取缓存使用率 + pub fn usage_percentage(&self) -> f64 { + if self.capacity == 0 { + 0.0 + } else { + (self.size as f64 / self.capacity as f64) * 100.0 + } + } + + /// 获取命中率(需要外部跟踪命中/未命中) + pub fn hit_rate(&self, hits: u64, misses: u64) -> f64 { + let total = hits + misses; + if total == 0 { + 0.0 + } else { + (hits as f64 / total as f64) * 100.0 + } + } +} diff --git a/crates/trusted-proxies/src/proxy/chain.rs b/crates/trusted-proxies/src/proxy/chain.rs index 6238cfff..92a5edee 100644 --- a/crates/trusted-proxies/src/proxy/chain.rs +++ b/crates/trusted-proxies/src/proxy/chain.rs @@ -11,3 +11,275 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Proxy chain analysis and validation + +use std::collections::HashSet; +use std::net::IpAddr; + +use axum::http::HeaderMap; +use tracing::{debug, trace}; + +use crate::config::{TrustedProxyConfig, ValidationMode}; +use crate::error::ProxyError; +use crate::utils::ip::is_valid_ip_address; + +/// 代理链分析结果 +#[derive(Debug, Clone)] +pub struct ChainAnalysis { + /// 客户端真实 IP + pub client_ip: IpAddr, + /// 已验证的代理跳数 + pub hops: usize, + /// 是否连续 + pub is_continuous: bool, + /// 警告信息 + pub warnings: Vec, + /// 使用的验证模式 + pub validation_mode: ValidationMode, + /// 可信代理部分 + pub trusted_chain: Vec, +} + +/// 代理链分析器 +#[derive(Debug, Clone)] +pub struct ProxyChainAnalyzer { + /// 代理配置 + config: TrustedProxyConfig, + /// 已验证的可信代理 IP 缓存(用于快速查找) + trusted_ip_cache: HashSet, +} + +impl ProxyChainAnalyzer { + /// 创建新的代理链分析器 + pub fn new(config: TrustedProxyConfig) -> Self { + // 构建可信 IP 缓存 + let mut trusted_ip_cache = HashSet::new(); + + for proxy in &config.proxies { + match proxy { + crate::config::TrustedProxy::Single(ip) => { + trusted_ip_cache.insert(*ip); + } + crate::config::TrustedProxy::Cidr(network) => { + // 对于小网络,可以缓存所有 IP + // 这里我们只缓存/24及更小的网络的前几个IP作为示例 + // 实际生产环境中可能需要更复杂的缓存策略 + if network.prefix() >= 24 && network.prefix() <= 30 { + // 对于小网络,我们可以缓存网络地址和广播地址之间的几个 IP + // 这里简化处理,只缓存网络地址 + if let Some(first_ip) = network.iter().next() { + trusted_ip_cache.insert(first_ip); + } + } + } + } + } + + Self { + config, + trusted_ip_cache, + } + } + + /// 分析代理链 + pub fn analyze_chain( + &self, + proxy_chain: &[IpAddr], + current_proxy_ip: IpAddr, + headers: &HeaderMap, + ) -> Result { + trace!("Analyzing proxy chain: {:?} with current proxy: {}", proxy_chain, current_proxy_ip); + + // 验证 IP 地址 + self.validate_ip_addresses(proxy_chain)?; + + // 构建完整链(包括当前代理) + let mut full_chain = proxy_chain.to_vec(); + full_chain.push(current_proxy_ip); + + // 根据验证模式分析链 + let (client_ip, trusted_chain, hops) = match self.config.validation_mode { + ValidationMode::Lenient => self.analyze_lenient(&full_chain), + ValidationMode::Strict => self.analyze_strict(&full_chain)?, + ValidationMode::HopByHop => self.analyze_hop_by_hop(&full_chain), + }; + + // 检查链连续性 + let is_continuous = if self.config.enable_chain_continuity_check { + self.check_chain_continuity(&full_chain, &trusted_chain) + } else { + true + }; + + // 收集警告 + let warnings = self.collect_warnings(&full_chain, &trusted_chain, headers); + + // 验证客户端 IP + if !is_valid_ip_address(&client_ip) { + return Err(ProxyError::internal(format!("Invalid client IP: {}", client_ip))); + } + + Ok(ChainAnalysis { + client_ip, + hops, + is_continuous, + warnings, + validation_mode: self.config.validation_mode, + trusted_chain, + }) + } + + /// 宽松模式分析:只要最后一个代理可信,就接受整个链 + fn analyze_lenient(&self, chain: &[IpAddr]) -> (IpAddr, Vec, usize) { + if chain.is_empty() { + return (IpAddr::from([0, 0, 0, 0]), Vec::new(), 0); + } + + // 检查最后一个代理是否可信 + if let Some(last_proxy) = chain.last() { + if self.is_ip_trusted(last_proxy) { + // 整个链都可信 + let client_ip = chain.first().copied().unwrap_or(*last_proxy); + return (client_ip, chain.to_vec(), chain.len()); + } + } + + // 如果最后一个代理不可信,使用链中第一个 IP 作为客户端 + let client_ip = chain.first().copied().unwrap_or(IpAddr::from([0, 0, 0, 0])); + (client_ip, Vec::new(), 0) + } + + /// 严格模式分析:要求链中所有代理都可信 + fn analyze_strict(&self, chain: &[IpAddr]) -> Result<(IpAddr, Vec, usize), ProxyError> { + if chain.is_empty() { + return Ok((IpAddr::from([0, 0, 0, 0]), Vec::new(), 0)); + } + + // 检查每个代理是否都可信 + for (i, ip) in chain.iter().enumerate() { + if !self.is_ip_trusted(ip) { + return Err(ProxyError::chain_failed(format!("Proxy at position {} ({}) is not trusted", i, ip))); + } + } + + let client_ip = chain.first().copied().unwrap_or(IpAddr::from([0, 0, 0, 0])); + Ok((client_ip, chain.to_vec(), chain.len())) + } + + /// 跳数模式分析:从右向左找到第一个不可信代理 + fn analyze_hop_by_hop(&self, chain: &[IpAddr]) -> (IpAddr, Vec, usize) { + if chain.is_empty() { + return (IpAddr::from([0, 0, 0, 0]), Vec::new(), 0); + } + + let mut trusted_chain = Vec::new(); + let mut validated_hops = 0; + + // 从右向左遍历(从离我们最近的代理开始) + for ip in chain.iter().rev() { + if self.is_ip_trusted(ip) { + trusted_chain.insert(0, *ip); + validated_hops += 1; + } else { + // 找到第一个不可信代理,停止遍历 + break; + } + } + + if trusted_chain.is_empty() { + // 没有可信代理,使用链的最后一个 IP + let client_ip = *chain.last().unwrap(); + (client_ip, vec![client_ip], 0) + } else { + // 客户端 IP 是可信链的第一个 IP 之前的那个 IP + let client_ip_index = chain.len().saturating_sub(trusted_chain.len()); + let client_ip = if client_ip_index > 0 { + chain[client_ip_index - 1] + } else { + // 如果整个链都可信,使用第一个 IP + chain[0] + }; + + (client_ip, trusted_chain, validated_hops) + } + } + + /// 检查链连续性 + fn check_chain_continuity(&self, full_chain: &[IpAddr], trusted_chain: &[IpAddr]) -> bool { + if full_chain.len() <= 1 || trusted_chain.is_empty() { + return true; + } + + // 可信链应该是完整链的尾部连续部分 + if trusted_chain.len() > full_chain.len() { + return false; + } + + let expected_tail = &full_chain[full_chain.len() - trusted_chain.len()..]; + expected_tail == trusted_chain + } + + /// 验证 IP 地址 + fn validate_ip_addresses(&self, chain: &[IpAddr]) -> Result<(), ProxyError> { + for ip in chain { + if !is_valid_ip_address(ip) { + return Err(ProxyError::IpParseError(format!("Invalid IP address in chain: {}", ip))); + } + + // 检查是否为特殊地址 + if ip.is_unspecified() { + return Err(ProxyError::invalid_xff("IP address cannot be unspecified (0.0.0.0 or ::)")); + } + + if ip.is_multicast() { + return Err(ProxyError::invalid_xff("IP address cannot be multicast")); + } + } + + Ok(()) + } + + /// 检查 IP 是否可信 + fn is_ip_trusted(&self, ip: &IpAddr) -> bool { + // 首先检查缓存 + if self.trusted_ip_cache.contains(ip) { + return true; + } + + // 然后检查配置中的代理 + self.config.proxies.iter().any(|proxy| proxy.contains(ip)) + } + + /// 收集警告信息 + fn collect_warnings(&self, full_chain: &[IpAddr], trusted_chain: &[IpAddr], headers: &HeaderMap) -> Vec { + let mut warnings = Vec::new(); + + // 检查代理链长度 + if full_chain.len() > self.config.max_hops { + warnings.push(format!( + "Proxy chain length ({}) exceeds recommended maximum ({})", + full_chain.len(), + self.config.max_hops + )); + } + + // 检查是否缺少必要的头部 + if trusted_chain.len() > 0 { + if !headers.contains_key("x-forwarded-for") && !headers.contains_key("forwarded") { + warnings.push("No proxy headers found for trusted proxy request".to_string()); + } + } + + // 检查是否有重复的 IP + let mut seen_ips = HashSet::new(); + for ip in full_chain { + if !seen_ips.insert(ip) { + warnings.push(format!("Duplicate IP in proxy chain: {}", ip)); + break; + } + } + + warnings + } +} diff --git a/crates/trusted-proxies/src/proxy/metrics.rs b/crates/trusted-proxies/src/proxy/metrics.rs index 6238cfff..99b5e633 100644 --- a/crates/trusted-proxies/src/proxy/metrics.rs +++ b/crates/trusted-proxies/src/proxy/metrics.rs @@ -11,3 +11,201 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Metrics and monitoring for proxy validation + +use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; +use std::time::Duration; +use tracing::{info, warn}; + +use crate::config::ValidationMode; +use crate::error::ProxyError; + +/// 代理验证指标 +#[derive(Debug, Clone)] +pub struct ProxyMetrics { + /// 是否启用指标 + enabled: bool, + /// 应用名称(用于指标标签) + app_name: String, +} + +impl ProxyMetrics { + /// 创建新的指标收集器 + pub fn new(app_name: &str, enabled: bool) -> Self { + let metrics = Self { + enabled, + app_name: app_name.to_string(), + }; + + // 注册指标描述 + metrics.register_descriptions(); + + metrics + } + + /// 注册指标描述 + fn register_descriptions(&self) { + if !self.enabled { + return; + } + + describe_counter!("proxy_validation_attempts_total", "Total number of proxy validation attempts"); + + describe_counter!("proxy_validation_success_total", "Total number of successful proxy validations"); + + describe_counter!("proxy_validation_failure_total", "Total number of failed proxy validations"); + + describe_counter!( + "proxy_validation_failure_by_type_total", + "Total number of failed proxy validations by error type" + ); + + describe_gauge!("proxy_chain_length", "Length of proxy chains being validated"); + + describe_histogram!("proxy_validation_duration_seconds", "Duration of proxy validation in seconds"); + + describe_gauge!("proxy_cache_size", "Size of the proxy validation cache"); + + describe_counter!("proxy_cache_hits_total", "Total number of cache hits"); + + describe_counter!("proxy_cache_misses_total", "Total number of cache misses"); + } + + /// 记录验证尝试 + pub fn increment_validation_attempts(&self) { + if !self.enabled { + return; + } + + counter!( + "proxy_validation_attempts_total", + 1, + "app" => self.app_name.clone() + ); + } + + /// 记录验证成功 + pub fn record_validation_success(&self, from_trusted_proxy: bool, proxy_hops: usize, duration: Duration) { + if !self.enabled { + return; + } + + counter!( + "proxy_validation_success_total", + 1, + "app" => self.app_name.clone(), + "trusted" => from_trusted_proxy.to_string() + ); + + gauge!( + "proxy_chain_length", + proxy_hops as f64, + "app" => self.app_name.clone() + ); + + histogram!( + "proxy_validation_duration_seconds", + duration.as_secs_f64(), + "app" => self.app_name.clone() + ); + } + + /// 记录验证失败 + pub fn record_validation_failure(&self, error: &ProxyError, duration: Duration) { + if !self.enabled { + return; + } + + let error_type = match error { + ProxyError::InvalidXForwardedFor(_) => "invalid_x_forwarded_for", + ProxyError::InvalidForwardedHeader(_) => "invalid_forwarded_header", + ProxyError::ChainValidationFailed(_) => "chain_validation_failed", + ProxyError::ChainTooLong(_, _) => "chain_too_long", + ProxyError::UntrustedProxy(_) => "untrusted_proxy", + ProxyError::ChainNotContinuous => "chain_not_continuous", + ProxyError::IpParseError(_) => "ip_parse_error", + ProxyError::HeaderParseError(_) => "header_parse_error", + ProxyError::Timeout => "timeout", + ProxyError::Internal(_) => "internal", + }; + + counter!( + "proxy_validation_failure_total", + 1, + "app" => self.app_name.clone(), + "error_type" => error_type + ); + + counter!( + "proxy_validation_failure_by_type_total", + 1, + "app" => self.app_name.clone(), + "error_type" => error_type + ); + + histogram!( + "proxy_validation_duration_seconds", + duration.as_secs_f64(), + "app" => self.app_name.clone(), + "error_type" => error_type + ); + } + + /// 记录验证模式使用情况 + pub fn record_validation_mode(&self, mode: ValidationMode) { + if !self.enabled { + return; + } + + gauge!( + "proxy_validation_mode", + match mode { + ValidationMode::Lenient => 0.0, + ValidationMode::Strict => 1.0, + ValidationMode::HopByHop => 2.0, + }, + "app" => self.app_name.clone(), + "mode" => mode.as_str() + ); + } + + /// 记录缓存指标 + pub fn record_cache_metrics(&self, hits: u64, misses: u64, size: usize) { + if !self.enabled { + return; + } + + counter!("proxy_cache_hits_total", hits, "app" => self.app_name.clone()); + counter!("proxy_cache_misses_total", misses, "app" => self.app_name.clone()); + gauge!("proxy_cache_size", size as f64, "app" => self.app_name.clone()); + } + + /// 打印指标摘要 + pub fn print_summary(&self) { + if !self.enabled { + info!("Metrics collection is disabled"); + return; + } + + info!("Proxy metrics enabled for application: {}", self.app_name); + info!("Available metrics:"); + info!(" - proxy_validation_attempts_total"); + info!(" - proxy_validation_success_total"); + info!(" - proxy_validation_failure_total"); + info!(" - proxy_validation_failure_by_type_total"); + info!(" - proxy_chain_length"); + info!(" - proxy_validation_duration_seconds"); + info!(" - proxy_cache_size"); + info!(" - proxy_cache_hits_total"); + info!(" - proxy_cache_misses_total"); + } +} + +/// 默认应用名称 +const DEFAULT_APP_NAME: &str = "trusted-proxy"; + +/// 创建默认的代理指标收集器 +pub fn default_proxy_metrics(enabled: bool) -> ProxyMetrics { + ProxyMetrics::new(DEFAULT_APP_NAME, enabled) +} diff --git a/crates/trusted-proxies/src/proxy/mod.rs b/crates/trusted-proxies/src/proxy/mod.rs index 51e96a20..e4bf1918 100644 --- a/crates/trusted-proxies/src/proxy/mod.rs +++ b/crates/trusted-proxies/src/proxy/mod.rs @@ -12,6 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Core proxy handling module +//! +//! This module contains the main logic for validating and processing +//! requests through trusted proxies. + mod cache; mod chain; mod metrics; +mod validator; + +pub use cache::*; +pub use chain::*; +pub use metrics::*; +pub use validator::*; + +// Re-export commonly used types +pub use crate::config::{TrustedProxyConfig, ValidationMode}; +pub use crate::error::ProxyError; diff --git a/crates/trusted-proxies/src/proxy/validator.rs b/crates/trusted-proxies/src/proxy/validator.rs index 6238cfff..0dfc279e 100644 --- a/crates/trusted-proxies/src/proxy/validator.rs +++ b/crates/trusted-proxies/src/proxy/validator.rs @@ -11,3 +11,326 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Proxy validator for validating proxy chains and client information + +use std::net::{IpAddr, SocketAddr}; +use std::time::Instant; + +use axum::http::HeaderMap; +use tracing::{debug, trace, warn}; + +use crate::config::{TrustedProxyConfig, ValidationMode}; +use crate::error::ProxyError; +use crate::proxy::chain::ProxyChainAnalyzer; +use crate::proxy::metrics::ProxyMetrics; + +/// 客户端信息验证结果 +#[derive(Debug, Clone)] +pub struct ClientInfo { + /// 真实客户端 IP 地址(已验证) + pub real_ip: IpAddr, + /// 原始请求主机名(如果来自可信代理) + pub forwarded_host: Option, + /// 原始请求协议(如果来自可信代理) + pub forwarded_proto: Option, + /// 请求是否来自可信代理 + pub is_from_trusted_proxy: bool, + /// 直接连接的代理 IP(如果经过代理) + pub proxy_ip: Option, + /// 代理链长度 + pub proxy_hops: usize, + /// 验证模式 + pub validation_mode: ValidationMode, + /// 验证警告信息 + pub warnings: Vec, +} + +impl ClientInfo { + /// 创建直接连接的客户端信息(无代理) + pub fn direct(addr: SocketAddr) -> Self { + Self { + real_ip: addr.ip(), + forwarded_host: None, + forwarded_proto: None, + is_from_trusted_proxy: false, + proxy_ip: None, + proxy_hops: 0, + validation_mode: ValidationMode::Lenient, + warnings: Vec::new(), + } + } + + /// 从可信代理创建客户端信息 + pub fn from_trusted_proxy( + real_ip: IpAddr, + forwarded_host: Option, + forwarded_proto: Option, + proxy_ip: IpAddr, + proxy_hops: usize, + validation_mode: ValidationMode, + warnings: Vec, + ) -> Self { + Self { + real_ip, + forwarded_host, + forwarded_proto, + is_from_trusted_proxy: true, + proxy_ip: Some(proxy_ip), + proxy_hops, + validation_mode, + warnings, + } + } + + /// 获取客户端信息的字符串表示(用于日志) + pub fn to_log_string(&self) -> String { + format!( + "client_ip={}, proxy={:?}, hops={}, trusted={}, mode={:?}", + self.real_ip, self.proxy_ip, self.proxy_hops, self.is_from_trusted_proxy, self.validation_mode + ) + } +} + +/// 代理验证器 +#[derive(Debug, Clone)] +pub struct ProxyValidator { + /// 代理配置 + config: TrustedProxyConfig, + /// 代理链分析器 + chain_analyzer: ProxyChainAnalyzer, + /// 监控指标 + metrics: Option, +} + +impl ProxyValidator { + /// 创建新的代理验证器 + pub fn new(config: TrustedProxyConfig, metrics: Option) -> Self { + let chain_analyzer = ProxyChainAnalyzer::new(config.clone()); + + Self { + config, + chain_analyzer, + metrics, + } + } + + /// 验证请求并提取客户端信息 + pub fn validate_request(&self, peer_addr: Option, headers: &HeaderMap) -> Result { + let start_time = Instant::now(); + + // 记录验证开始 + self.record_metric_start(); + + // 验证请求 + let result = self.validate_request_internal(peer_addr, headers); + + // 记录验证结果 + let duration = start_time.elapsed(); + self.record_metric_result(&result, duration); + + result + } + + /// 内部验证逻辑 + fn validate_request_internal(&self, peer_addr: Option, headers: &HeaderMap) -> Result { + // 如果没有对端地址,使用默认值 + let peer_addr = peer_addr.unwrap_or_else(|| SocketAddr::new(IpAddr::from([0, 0, 0, 0]), 0)); + + // 检查是否来自可信代理 + if self.config.is_trusted(&peer_addr) { + debug!("Request from trusted proxy: {}", peer_addr.ip()); + + // 来自可信代理,解析转发头部 + self.validate_trusted_proxy_request(&peer_addr, headers) + } else { + // 检查是否为私有网络地址 + if self.config.is_private_network(&peer_addr.ip()) { + warn!( + "Request from private network but not trusted: {}. This might be a configuration issue.", + peer_addr.ip() + ); + } + + // 来自不可信代理或直接连接 + Ok(ClientInfo::direct(peer_addr)) + } + } + + /// 验证来自可信代理的请求 + fn validate_trusted_proxy_request(&self, proxy_addr: &SocketAddr, headers: &HeaderMap) -> Result { + let proxy_ip = proxy_addr.ip(); + + // 优先使用 RFC 7239 Forwarded 头部(如果启用) + let client_info = if self.config.enable_rfc7239 { + self.try_parse_rfc7239_headers(headers, proxy_ip) + .unwrap_or_else(|| self.parse_legacy_headers(headers, proxy_ip)) + } else { + self.parse_legacy_headers(headers, proxy_ip) + }; + + // 验证代理链 + let chain_analysis = self + .chain_analyzer + .analyze_chain(&client_info.proxy_chain, proxy_ip, headers)?; + + // 检查代理链长度 + if chain_analysis.hops > self.config.max_hops { + return Err(ProxyError::ChainTooLong(chain_analysis.hops, self.config.max_hops)); + } + + // 检查链连续性(如果启用) + if self.config.enable_chain_continuity_check && !chain_analysis.is_continuous { + return Err(ProxyError::ChainNotContinuous); + } + + // 创建客户端信息 + let warnings = if !chain_analysis.warnings.is_empty() { + chain_analysis.warnings + } else { + Vec::new() + }; + + Ok(ClientInfo::from_trusted_proxy( + chain_analysis.client_ip, + client_info.forwarded_host, + client_info.forwarded_proto, + proxy_ip, + chain_analysis.hops, + self.config.validation_mode, + warnings, + )) + } + + /// 尝试解析 RFC 7239 Forwarded 头部 + fn try_parse_rfc7239_headers(&self, headers: &HeaderMap, proxy_ip: IpAddr) -> Option { + headers + .get("forwarded") + .and_then(|h| h.to_str().ok()) + .and_then(|s| Self::parse_forwarded_header(s, proxy_ip)) + } + + /// 解析传统的代理头部 + fn parse_legacy_headers(&self, headers: &HeaderMap, proxy_ip: IpAddr) -> ParsedHeaders { + let forwarded_host = headers + .get("x-forwarded-host") + .and_then(|h| h.to_str().ok()) + .map(String::from); + + let forwarded_proto = headers + .get("x-forwarded-proto") + .and_then(|h| h.to_str().ok()) + .map(String::from); + + let proxy_chain = headers + .get("x-forwarded-for") + .and_then(|h| h.to_str().ok()) + .map(|s| Self::parse_x_forwarded_for(s)) + .unwrap_or_else(Vec::new); + + ParsedHeaders { + proxy_chain, + forwarded_host, + forwarded_proto, + } + } + + /// 解析 RFC 7239 Forwarded 头部 + fn parse_forwarded_header(header_value: &str, proxy_ip: IpAddr) -> Option { + // 简化实现:只处理第一个值 + let first_part = header_value.split(',').next()?.trim(); + + let mut proxy_chain = Vec::new(); + let mut forwarded_host = None; + let mut forwarded_proto = None; + + // 解析键值对 + for part in first_part.split(';') { + let part = part.trim(); + if let Some((key, value)) = part.split_once('=') { + let key = key.trim().to_lowercase(); + let value = value.trim().trim_matches('"'); + + match key.as_str() { + "for" => { + // 解析客户端 IP(可能包含端口) + if let Some(ip_part) = value.split(':').next() { + if let Ok(ip) = ip_part.parse::() { + proxy_chain.push(ip); + } + } + } + "host" => { + forwarded_host = Some(value.to_string()); + } + "proto" => { + forwarded_proto = Some(value.to_string()); + } + _ => {} + } + } + } + + // 如果没有找到客户端 IP,添加代理 IP 作为备选 + if proxy_chain.is_empty() { + proxy_chain.push(proxy_ip); + } + + Some(ParsedHeaders { + proxy_chain, + forwarded_host, + forwarded_proto, + }) + } + + /// 解析 X-Forwarded-For 头部 + fn parse_x_forwarded_for(header_value: &str) -> Vec { + header_value + .split(',') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .filter_map(|s| { + // 移除端口部分(如果存在) + let ip_part = s.split(':').next().unwrap_or(s); + ip_part.parse::().ok() + }) + .collect() + } + + /// 记录验证开始指标 + fn record_metric_start(&self) { + if let Some(metrics) = &self.metrics { + metrics.increment_validation_attempts(); + } + } + + /// 记录验证结果指标 + fn record_metric_result(&self, result: &Result, duration: std::time::Duration) { + if let Some(metrics) = &self.metrics { + match result { + Ok(client_info) => { + metrics.record_validation_success(client_info.is_from_trusted_proxy, client_info.proxy_hops, duration); + } + Err(err) => { + metrics.record_validation_failure(err, duration); + + // 记录失败的验证(如果启用) + if self.config.log_failed_validations { + warn!("Proxy validation failed: {}", err); + } + } + } + } + } +} + +/// 解析后的头部信息 +#[derive(Debug, Clone)] +struct ParsedHeaders { + /// 代理链(客户端 IP 在第一个位置) + proxy_chain: Vec, + /// 转发的主机名 + forwarded_host: Option, + /// 转发的协议 + forwarded_proto: Option, +} diff --git a/crates/trusted-proxies/src/state.rs b/crates/trusted-proxies/src/state.rs new file mode 100644 index 00000000..d5a94e18 --- /dev/null +++ b/crates/trusted-proxies/src/state.rs @@ -0,0 +1,25 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::{AppConfig, ProxyMetrics}; +use std::sync::Arc; + +/// 应用状态 +#[derive(Clone)] +pub struct AppState { + /// 应用配置 + pub config: Arc, + /// 代理指标收集器 + pub metrics: Option, +} diff --git a/crates/trusted-proxies/src/utils/ip.rs b/crates/trusted-proxies/src/utils/ip.rs index 6238cfff..2af0a06f 100644 --- a/crates/trusted-proxies/src/utils/ip.rs +++ b/crates/trusted-proxies/src/utils/ip.rs @@ -11,3 +11,255 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! IP address utility functions + +use ipnetwork::IpNetwork; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::str::FromStr; + +/// IP 工具函数集合 +pub struct IpUtils; + +impl IpUtils { + /// 检查 IP 地址是否有效 + pub fn is_valid_ip_address(ip: &IpAddr) -> bool { + !ip.is_unspecified() && !ip.is_multicast() && !Self::is_reserved_ip(ip) + } + + /// 检查 IP 是否为保留地址 + pub fn is_reserved_ip(ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(ipv4) => Self::is_reserved_ipv4(ipv4), + IpAddr::V6(ipv6) => Self::is_reserved_ipv6(ipv6), + } + } + + /// 检查 IPv4 是否为保留地址 + pub fn is_reserved_ipv4(ip: &Ipv4Addr) -> bool { + let octets = ip.octets(); + + // 检查常见的保留地址范围 + matches!( + octets, + [0, _, _, _] | // 0.0.0.0/8 + [10, _, _, _] | // 10.0.0.0/8 + [100, 64, _, _] | // 100.64.0.0/10 + [127, _, _, _] | // 127.0.0.0/8 + [169, 254, _, _] | // 169.254.0.0/16 + [172, 16..=31, _, _] | // 172.16.0.0/12 + [192, 0, 0, _] | // 192.0.0.0/24 + [192, 0, 2, _] | // 192.0.2.0/24 + [192, 88, 99, _] | // 192.88.99.0/24 + [192, 168, _, _] | // 192.168.0.0/16 + [198, 18..=19, _, _] | // 198.18.0.0/15 + [198, 51, 100, _] | // 198.51.100.0/24 + [203, 0, 113, _] | // 203.0.113.0/24 + [224..=239, _, _, _] | // 224.0.0.0/4 + [240..=255, _, _, _] // 240.0.0.0/4 + ) + } + + /// 检查 IPv6 是否为保留地址 + pub fn is_reserved_ipv6(ip: &Ipv6Addr) -> bool { + let segments = ip.segments(); + + // 检查常见的保留地址范围 + matches!( + segments, + [0, 0, 0, 0, 0, 0, 0, 0] | // ::/128 + [0, 0, 0, 0, 0, 0, 0, 1] | // ::1/128 + [0x2001, 0xdb8, _, _, _, _, _, _] | // 2001:db8::/32 + [0xfc00..=0xfdff, _, _, _, _, _, _, _] | // fc00::/7 + [0xfe80..=0xfebf, _, _, _, _, _, _, _] | // fe80::/10 + [0xff00..=0xffff, _, _, _, _, _, _, _] // ff00::/8 + ) + } + + /// 检查 IP 是否为私有地址 + pub fn is_private_ip(ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(ipv4) => Self::is_private_ipv4(ipv4), + IpAddr::V6(ipv6) => Self::is_private_ipv6(ipv6), + } + } + + /// 检查 IPv4 是否为私有地址 + pub fn is_private_ipv4(ip: &Ipv4Addr) -> bool { + let octets = ip.octets(); + + matches!( + octets, + [10, _, _, _] | // 10.0.0.0/8 + [172, 16..=31, _, _] | // 172.16.0.0/12 + [192, 168, _, _] // 192.168.0.0/16 + ) + } + + /// 检查 IPv6 是否为私有地址 + pub fn is_private_ipv6(ip: &Ipv6Addr) -> bool { + let segments = ip.segments(); + + matches!( + segments, + [0xfc00..=0xfdff, _, _, _, _, _, _, _] // fc00::/7 + ) + } + + /// 检查 IP 是否为回环地址 + pub fn is_loopback_ip(ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(ipv4) => ipv4.is_loopback(), + IpAddr::V6(ipv6) => ipv6.is_loopback(), + } + } + + /// 检查 IP 是否为链路本地地址 + pub fn is_link_local_ip(ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(ipv4) => ipv4.is_link_local(), + IpAddr::V6(ipv6) => ipv6.is_unicast_link_local(), + } + } + + /// 检查 IP 是否为文档地址(TEST-NET) + pub fn is_documentation_ip(ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(ipv4) => { + let octets = ipv4.octets(); + matches!( + octets, + [192, 0, 2, _] | // 192.0.2.0/24 + [198, 51, 100, _] | // 198.51.100.0/24 + [203, 0, 113, _] // 203.0.113.0/24 + ) + } + IpAddr::V6(ipv6) => { + let segments = ipv6.segments(); + matches!(segments, [0x2001, 0xdb8, _, _, _, _, _, _]) // 2001:db8::/32 + } + } + } + + /// 从字符串解析 IP 地址,支持 CIDR 表示法 + pub fn parse_ip_or_cidr(s: &str) -> Result { + IpNetwork::from_str(s).map_err(|e| format!("Failed to parse IP/CIDR '{}': {}", s, e)) + } + + /// 从逗号分隔的字符串解析 IP 列表 + pub fn parse_ip_list(s: &str) -> Result, String> { + let mut ips = Vec::new(); + + for part in s.split(',') { + let part = part.trim(); + if part.is_empty() { + continue; + } + + match IpAddr::from_str(part) { + Ok(ip) => ips.push(ip), + Err(e) => return Err(format!("Failed to parse IP '{}': {}", part, e)), + } + } + + Ok(ips) + } + + /// 从逗号分隔的字符串解析网络列表 + pub fn parse_network_list(s: &str) -> Result, String> { + let mut networks = Vec::new(); + + for part in s.split(',') { + let part = part.trim(); + if part.is_empty() { + continue; + } + + match Self::parse_ip_or_cidr(part) { + Ok(network) => networks.push(network), + Err(e) => return Err(e), + } + } + + Ok(networks) + } + + /// 检查 IP 是否在给定的网络列表中 + pub fn ip_in_networks(ip: &IpAddr, networks: &[IpNetwork]) -> bool { + networks.iter().any(|network| network.contains(*ip)) + } + + /// 获取 IP 地址的类型描述 + pub fn get_ip_type(ip: &IpAddr) -> &'static str { + if Self::is_private_ip(ip) { + "private" + } else if Self::is_loopback_ip(ip) { + "loopback" + } else if Self::is_link_local_ip(ip) { + "link_local" + } else if Self::is_documentation_ip(ip) { + "documentation" + } else if Self::is_reserved_ip(ip) { + "reserved" + } else { + "public" + } + } + + /// 将 IP 地址转换为规范形式 + pub fn canonical_ip(ip: &IpAddr) -> String { + match ip { + IpAddr::V4(ipv4) => ipv4.to_string(), + IpAddr::V6(ipv6) => { + // 压缩 IPv6 地址 + let mut result = String::new(); + let segments = ipv6.segments(); + + // 查找最长的连续零段 + let mut longest_start = 0; + let mut longest_len = 0; + let mut current_start = 0; + let mut current_len = 0; + + for (i, &segment) in segments.iter().enumerate() { + if segment == 0 { + if current_len == 0 { + current_start = i; + } + current_len += 1; + } else { + if current_len > longest_len { + longest_start = current_start; + longest_len = current_len; + } + current_len = 0; + } + } + + if current_len > longest_len { + longest_start = current_start; + longest_len = current_len; + } + + // 格式化为字符串 + for mut i in 0..8 { + if i == longest_start && longest_len > 1 { + result.push_str("::"); + i += longest_len - 1; + } else if i == longest_start && longest_len == 1 { + result.push('0'); + } else { + if i > 0 && i != longest_start { + result.push(':'); + } + if segments[i] != 0 || (i == 7 && result.is_empty()) { + result.push_str(&format!("{:x}", segments[i])); + } + } + } + + result + } + } + } +} diff --git a/crates/trusted-proxies/src/utils/mod.rs b/crates/trusted-proxies/src/utils/mod.rs index f087e3e2..e40c4668 100644 --- a/crates/trusted-proxies/src/utils/mod.rs +++ b/crates/trusted-proxies/src/utils/mod.rs @@ -12,5 +12,78 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Utility functions and helpers + mod ip; mod validation; + +pub use ip::*; +pub use validation::*; + +/// 工具函数集合 +#[derive(Debug, Clone)] +pub struct Utils; + +impl Utils { + /// 生成追踪 ID + pub fn generate_trace_id() -> String { + format!("trace-{}", uuid::Uuid::new_v4()) + } + + /// 生成 Span ID + pub fn generate_span_id() -> String { + format!("span-{}", uuid::Uuid::new_v4()) + } + + /// 安全的将字符串转换为 usize + pub fn safe_parse_usize(s: &str, default: usize) -> usize { + s.parse().unwrap_or(default) + } + + /// 安全的将字符串转换为 u64 + pub fn safe_parse_u64(s: &str, default: u64) -> u64 { + s.parse().unwrap_or(default) + } + + /// 安全的将字符串转换为布尔值 + pub fn safe_parse_bool(s: &str, default: bool) -> bool { + match s.to_lowercase().as_str() { + "true" | "1" | "yes" | "on" => true, + "false" | "0" | "no" | "off" => false, + _ => default, + } + } + + /// 格式化持续时间 + pub fn format_duration(duration: std::time::Duration) -> String { + if duration.as_secs() > 0 { + format!("{:.2}s", duration.as_secs_f64()) + } else if duration.as_millis() > 0 { + format!("{}ms", duration.as_millis()) + } else if duration.as_micros() > 0 { + format!("{}µs", duration.as_micros()) + } else { + format!("{}ns", duration.as_nanos()) + } + } + + /// 获取当前时间戳 + pub fn current_timestamp() -> String { + chrono::Utc::now().to_rfc3339() + } + + /// 安全的获取环境变量 + pub fn get_env_var(key: &str) -> Option { + std::env::var(key).ok() + } + + /// 获取环境变量,如果不存在则使用默认值 + pub fn get_env_var_or(key: &str, default: &str) -> String { + std::env::var(key).unwrap_or_else(|_| default.to_string()) + } + + /// 检查环境变量是否存在 + pub fn has_env_var(key: &str) -> bool { + std::env::var(key).is_ok() + } +} diff --git a/crates/trusted-proxies/src/utils/validation.rs b/crates/trusted-proxies/src/utils/validation.rs index 6238cfff..e4bdbf43 100644 --- a/crates/trusted-proxies/src/utils/validation.rs +++ b/crates/trusted-proxies/src/utils/validation.rs @@ -11,3 +11,203 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Validation utility functions + +use http::HeaderMap; +use lazy_static::lazy_static; +use regex::Regex; +use std::net::IpAddr; +use std::str::FromStr; + +/// 验证工具函数集合 +pub struct ValidationUtils; + +impl ValidationUtils { + /// 验证电子邮件地址 + pub fn is_valid_email(email: &str) -> bool { + lazy_static! { + static ref EMAIL_REGEX: Regex = + Regex::new(r"^([a-z0-9_+]([a-z0-9_+.]*[a-z0-9_+])?)@([a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,6})").unwrap(); + } + + EMAIL_REGEX.is_match(email) + } + + /// 验证 URL + pub fn is_valid_url(url: &str) -> bool { + lazy_static! { + static ref URL_REGEX: Regex = + Regex::new(r"^(https?://)?([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,6}(/.*)?$").unwrap(); + } + + URL_REGEX.is_match(url) + } + + /// 验证 X-Forwarded-For 头部 + pub fn validate_x_forwarded_for(header_value: &str) -> bool { + if header_value.is_empty() { + return false; + } + + // 分割 IP 地址 + let ips: Vec<&str> = header_value.split(',').map(|s| s.trim()).collect(); + + // 检查每个 IP 地址 + for ip_str in ips { + if ip_str.is_empty() { + return false; + } + + // 移除端口部分(如果存在) + let ip_part = ip_str.split(':').next().unwrap_or(ip_str); + + if IpAddr::from_str(ip_part).is_err() { + return false; + } + } + + true + } + + /// 验证 Forwarded 头部(RFC 7239) + pub fn validate_forwarded_header(header_value: &str) -> bool { + if header_value.is_empty() { + return false; + } + + // 简化的验证:检查基本格式 + let parts: Vec<&str> = header_value.split(';').collect(); + + if parts.is_empty() { + return false; + } + + // 检查每个部分是否包含等号 + for part in parts { + let part = part.trim(); + if !part.contains('=') { + return false; + } + } + + true + } + + /// 验证 IP 地址是否在允许的范围内 + pub fn validate_ip_in_range(ip: &IpAddr, cidr_ranges: &[String]) -> bool { + for cidr in cidr_ranges { + if let Ok(network) = ipnetwork::IpNetwork::from_str(cidr) { + if network.contains(*ip) { + return true; + } + } + } + + false + } + + /// 验证头部是否包含恶意内容 + pub fn validate_header_value(value: &str) -> bool { + // 检查是否包含控制字符(除了水平制表符) + for c in value.chars() { + if c.is_control() && c != '\t' && c != '\n' && c != '\r' { + return false; + } + } + + // 检查长度限制(防止头部过大攻击) + if value.len() > 8192 { + return false; + } + + true + } + + /// 验证整个头部映射 + pub fn validate_headers(headers: &HeaderMap) -> bool { + for (name, value) in headers { + // 检查头部名称 + let name_str = name.as_str(); + if name_str.len() > 256 { + return false; + } + + // 检查头部值 + if let Ok(value_str) = value.to_str() { + if !Self::validate_header_value(value_str) { + return false; + } + } else { + // 无法转换为字符串,可能包含二进制数据 + if value.len() > 8192 { + return false; + } + } + } + + true + } + + /// 验证端口号 + pub fn validate_port(port: u16) -> bool { + port > 0 && port <= 65535 + } + + /// 验证 CIDR 表示法 + pub fn validate_cidr(cidr: &str) -> bool { + ipnetwork::IpNetwork::from_str(cidr).is_ok() + } + + /// 验证代理链长度 + pub fn validate_proxy_chain_length(chain: &[IpAddr], max_length: usize) -> bool { + chain.len() <= max_length + } + + /// 验证代理链是否连续 + pub fn validate_proxy_chain_continuity(chain: &[IpAddr]) -> bool { + if chain.len() < 2 { + return true; + } + + // 检查是否有重复的相邻 IP + for i in 1..chain.len() { + if chain[i] == chain[i - 1] { + return false; + } + } + + true + } + + /// 验证字符串是否只包含安全字符 + pub fn is_safe_string(s: &str) -> bool { + // 允许的字符:字母、数字、基本标点符号 + let safe_pattern = Regex::new(r"^[a-zA-Z0-9\-._~:/?#\[\]@!$&'()*+,;=]+$").unwrap(); + safe_pattern.is_match(s) + } + + /// 验证速率限制参数 + pub fn validate_rate_limit_params(requests: u32, period_seconds: u64) -> bool { + requests > 0 && requests <= 10000 && period_seconds > 0 && period_seconds <= 86400 + } + + /// 验证缓存参数 + pub fn validate_cache_params(capacity: usize, ttl_seconds: u64) -> bool { + capacity > 0 && capacity <= 1000000 && ttl_seconds > 0 && ttl_seconds <= 86400 + } + + /// 脱敏敏感数据 + pub fn mask_sensitive_data(data: &str, sensitive_patterns: &[&str]) -> String { + let mut result = data.to_string(); + + for pattern in sensitive_patterns { + let regex = Regex::new(&format!(r#"(?i){}[:=]\s*([^&\s]+)"#, pattern)).unwrap(); + result = regex + .replace_all(&result, |caps: ®ex::Captures| format!("{}:[REDACTED]", &caps[1])) + .to_string(); + } + + result + } +} diff --git a/crates/trusted-proxies/test/integration/api_tests.rs b/crates/trusted-proxies/test/integration/api_tests.rs new file mode 100644 index 00000000..6cac9353 --- /dev/null +++ b/crates/trusted-proxies/test/integration/api_tests.rs @@ -0,0 +1,178 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! API integration tests + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use axum::body::Body; + use axum::{extract::State, routing::get, Router}; + use serde_json::{json, Value}; + use tower::ServiceExt; + + use crate::config::{AppConfig, TrustedProxy, TrustedProxyConfig, ValidationMode}; + use crate::middleware::TrustedProxyLayer; + use crate::AppState; + + fn create_test_app_state() -> AppState { + let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; + + let proxy_config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]); + + let config = AppConfig::new( + proxy_config, + crate::config::CacheConfig::default(), + crate::config::MonitoringConfig::default(), + crate::config::CloudConfig::default(), + "127.0.0.1:3000".parse().unwrap(), + ); + + AppState { + config: Arc::new(config), + metrics: None, + } + } + + fn create_test_api_router() -> Router { + let state = create_test_app_state(); + let proxy_layer = TrustedProxyLayer::enabled(state.config.proxy.clone(), None); + + Router::new() + .route("/health", get(health_check)) + .route("/config", get(show_config)) + .with_state(state) + .layer(proxy_layer) + } + + async fn health_check() -> axum::response::Json { + axum::response::Json(json!({ + "status": "healthy", + "service": "trusted-proxy-test" + })) + } + + async fn show_config(State(state): State) -> axum::response::Json { + axum::response::Json(json!({ + "server": state.config.server_addr.to_string(), + "proxy": { + "trusted_networks": state.config.proxy.proxies.len(), + } + })) + } + + #[tokio::test] + async fn test_health_check_endpoint() { + let app = create_test_api_router(); + + let request = axum::http::Request::builder().uri("/health").body(Body::empty()).unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let json: Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["status"], "healthy"); + assert_eq!(json["service"], "trusted-proxy-test"); + } + + #[tokio::test] + async fn test_config_endpoint() { + let app = create_test_api_router(); + + let request = axum::http::Request::builder().uri("/config").body(Body::empty()).unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let json: Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["server"], "127.0.0.1:3000"); + assert_eq!(json["proxy"]["trusted_networks"], 1); + } + + #[tokio::test] + async fn test_proxy_headers_in_api() { + let state = create_test_app_state(); + let proxy_layer = TrustedProxyLayer::enabled(state.config.proxy.clone(), None); + + let app = Router::new() + .route( + "/client-test", + get(|req: axum::extract::Request| async move { + let client_info = req.extensions().get::(); + match client_info { + Some(info) => axum::response::Json(json!({ + "client_ip": info.real_ip.to_string(), + "trusted": info.is_from_trusted_proxy + })), + None => axum::response::Json(json!({ + "error": "No client info" + })), + } + }), + ) + .with_state(state) + .layer(proxy_layer); + + // 测试带代理头部的请求 + let request = axum::http::Request::builder() + .uri("/client-test") + .header("X-Forwarded-For", "203.0.113.195") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let json: Value = serde_json::from_slice(&body).unwrap(); + + // 由于请求来自 127.0.0.1(可信代理),应该解析 X-Forwarded-For + if json.get("client_ip").is_some() { + let client_ip = json["client_ip"].as_str().unwrap(); + // 可能是 203.0.113.195 或 127.0.0.1,取决于中间件如何配置 + assert!(client_ip == "203.0.113.195" || client_ip == "127.0.0.1"); + } + } + + #[tokio::test] + async fn test_missing_endpoint() { + let app = create_test_api_router(); + + let request = axum::http::Request::builder().uri("/not-found").body(Body::empty()).unwrap(); + + let response = app.oneshot(request).await.unwrap(); + + // 应该返回 404 + assert_eq!(response.status(), 404); + } + + #[tokio::test] + async fn test_request_without_proxy_layer() { + // 创建没有代理中间件的路由 + let app = Router::new().route("/simple", get(|| async { "OK" })); + + let request = axum::http::Request::builder().uri("/simple").body(Body::empty()).unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(String::from_utf8(body.to_vec()).unwrap(), "OK"); + } +} diff --git a/crates/trusted-proxies/test/integration/cloud_tests.rs b/crates/trusted-proxies/test/integration/cloud_tests.rs index 6238cfff..4c8d69f8 100644 --- a/crates/trusted-proxies/test/integration/cloud_tests.rs +++ b/crates/trusted-proxies/test/integration/cloud_tests.rs @@ -11,3 +11,183 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Cloud metadata integration tests + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::cloud::detector::CloudDetector; + use crate::cloud::metadata::{AwsMetadataFetcher, AzureMetadataFetcher, GcpMetadataFetcher}; + use crate::cloud::ranges::{CloudflareIpRanges, GoogleCloudIpRanges}; + + #[tokio::test] + async fn test_cloud_detector_disabled() { + let detector = CloudDetector::new( + false, // 禁用 + std::time::Duration::from_secs(1), + None, + ); + + let provider = detector.detect_provider(); + assert!(provider.is_none()); + + let ranges = detector.fetch_trusted_ranges().await; + assert!(ranges.is_ok()); + assert!(ranges.unwrap().is_empty()); + } + + #[tokio::test] + async fn test_cloud_detector_forced_provider() { + let detector = CloudDetector::new(true, std::time::Duration::from_secs(1), Some("aws".to_string())); + + let provider = detector.detect_provider(); + assert!(provider.is_some()); + assert_eq!(provider.unwrap().name(), "aws"); + } + + #[tokio::test] + async fn test_aws_metadata_fetcher() { + let fetcher = AwsMetadataFetcher::new(); + + // 测试提供者名称 + assert_eq!(fetcher.provider_name(), "aws"); + + // 由于不在 AWS 环境中,这些调用应该失败或返回默认值 + let network_result = fetcher.fetch_network_cidrs().await; + // 可能返回默认范围或错误 + assert!(network_result.is_ok()); + + let public_result = fetcher.fetch_public_ip_ranges().await; + // 可能从 API 获取或返回空列表 + assert!(public_result.is_ok()); + } + + #[tokio::test] + async fn test_azure_metadata_fetcher() { + let fetcher = AzureMetadataFetcher::new(); + + // 测试提供者名称 + assert_eq!(fetcher.provider_name(), "azure"); + + // 由于不在 Azure 环境中,这些调用应该返回默认值 + let network_result = fetcher.fetch_network_cidrs().await; + assert!(network_result.is_ok()); + + let public_result = fetcher.fetch_public_ip_ranges().await; + assert!(public_result.is_ok()); + } + + #[tokio::test] + async fn test_gcp_metadata_fetcher() { + let fetcher = GcpMetadataFetcher::new(); + + // 测试提供者名称 + assert_eq!(fetcher.provider_name(), "gcp"); + + // 由于不在 GCP 环境中,这些调用应该返回默认值 + let network_result = fetcher.fetch_network_cidrs().await; + assert!(network_result.is_ok()); + + let public_result = fetcher.fetch_public_ip_ranges().await; + assert!(public_result.is_ok()); + } + + #[tokio::test] + async fn test_cloudflare_ip_ranges_static() { + let ranges = CloudflareIpRanges::fetch().await; + + assert!(ranges.is_ok()); + + let networks = ranges.unwrap(); + assert!(!networks.is_empty()); + + // 检查是否包含预期的范围 + let has_ipv4 = networks + .iter() + .any(|n| n.to_string().contains("103.21.244.0/22") || n.to_string().contains("198.41.128.0/17")); + + let has_ipv6 = networks + .iter() + .any(|n| n.to_string().contains("2400:cb00::/32") || n.to_string().contains("2606:4700::/32")); + + assert!(has_ipv4 || has_ipv6); + } + + #[tokio::test] + async fn test_google_cloud_ip_ranges_api_mock() { + // 创建模拟服务器 + let mock_server = MockServer::start().await; + + // 模拟 Google IP 范围 API 响应 + let mock_response = r#" + { + "prefixes": [ + {"ipv4Prefix": "8.34.208.0/20"}, + {"ipv4Prefix": "8.35.192.0/20"}, + {"ipv6Prefix": "2001:4860::/32"} + ] + } + "#; + + Mock::given(method("GET")) + .and(path("/ipranges/cloud.json")) + .respond_with(ResponseTemplate::new(200).set_body_string(mock_response)) + .mount(&mock_server) + .await; + + // 创建自定义客户端指向模拟服务器 + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(2)) + .build() + .unwrap(); + + let url = format!("{}/ipranges/cloud.json", mock_server.uri()); + + let response = client.get(&url).send().await.unwrap(); + assert_eq!(response.status(), 200); + + let body = response.text().await.unwrap(); + assert!(body.contains("8.34.208.0/20")); + assert!(body.contains("2001:4860::/32")); + } + + #[tokio::test] + async fn test_cloud_detector_try_all_providers() { + let detector = CloudDetector::new(true, std::time::Duration::from_secs(2), None); + + // 在测试环境中,所有提供者都应该失败或返回空列表 + let result = detector.try_all_providers().await; + + // 应该成功返回(即使是空列表) + assert!(result.is_ok()); + } + + #[test] + fn test_ip_network_parsing() { + // 测试 CIDR 解析 + let cidr = ipnetwork::IpNetwork::from_str("192.168.1.0/24"); + assert!(cidr.is_ok()); + + let network = cidr.unwrap(); + assert_eq!(network.prefix(), 24); + + // 测试 IP 包含检查 + let ip: std::net::IpAddr = "192.168.1.100".parse().unwrap(); + assert!(network.contains(ip)); + + let ip_outside: std::net::IpAddr = "192.168.2.100".parse().unwrap(); + assert!(!network.contains(ip_outside)); + + // 测试 IPv6 CIDR + let ipv6_cidr = ipnetwork::IpNetwork::from_str("2001:db8::/32"); + assert!(ipv6_cidr.is_ok()); + + let ipv6_network = ipv6_cidr.unwrap(); + assert_eq!(ipv6_network.prefix(), 32); + } +} diff --git a/crates/trusted-proxies/test/integration/mod.rs b/crates/trusted-proxies/test/integration/mod.rs index 6238cfff..fa0bece3 100644 --- a/crates/trusted-proxies/test/integration/mod.rs +++ b/crates/trusted-proxies/test/integration/mod.rs @@ -11,3 +11,14 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Integration tests for the trusted proxy system + +mod api_tests; +mod cloud_tests; +mod proxy_tests; + +// 重新导出测试模块 +pub use api_tests::*; +pub use cloud_tests::*; +pub use proxy_tests::*; diff --git a/crates/trusted-proxies/test/integration/proxy_tests.rs b/crates/trusted-proxies/test/integration/proxy_tests.rs index 6238cfff..381e2a10 100644 --- a/crates/trusted-proxies/test/integration/proxy_tests.rs +++ b/crates/trusted-proxies/test/integration/proxy_tests.rs @@ -11,3 +11,178 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Proxy system integration tests + +#[cfg(test)] +mod tests { + use std::error::Request; + use axum::body::Body; + use axum::{extract::Request, routing::get, Router}; + use tower::ServiceExt; + + use crate::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode}; + use crate::middleware::{ClientInfo, TrustedProxyLayer}; + + fn create_test_router() -> Router { + let proxies = vec![ + TrustedProxy::Single("127.0.0.1".parse().unwrap()), + TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), + ]; + + let config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]); + + let proxy_layer = TrustedProxyLayer::enabled(config, None); + + Router::new() + .route( + "/test", + get(|req: Request| async move { + let client_info = req.extensions().get::(); + match client_info { + Some(info) => { + format!("IP: {}, Trusted: {}, Hops: {}", info.real_ip, info.is_from_trusted_proxy, info.proxy_hops) + } + None => "No client info".to_string(), + } + }), + ) + .layer(proxy_layer) + } + + #[tokio::test] + async fn test_direct_connection() { + let app = create_test_router(); + + // 模拟直接连接(无代理头部) + let request = axum::http::Request::builder().uri("/test").body(Body::empty()).unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body_str = String::from_utf8(body.to_vec()).unwrap(); + + // 应该显示直接连接的 IP(在测试环境中可能是 0.0.0.0) + assert!(body_str.contains("IP:")); + } + + #[tokio::test] + async fn test_trusted_proxy_with_xff() { + let app = create_test_router(); + + // 模拟来自可信代理的请求 + let request = axum::http::Request::builder() + .uri("/test") + .header("X-Forwarded-For", "203.0.113.195, 10.0.1.100") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body_str = String::from_utf8(body.to_vec()).unwrap(); + + // 应该显示客户端 IP (203.0.113.195) + assert!(body_str.contains("203.0.113.195")); + assert!(body_str.contains("Trusted: true")); + } + + #[tokio::test] + async fn test_untrusted_proxy_with_xff() { + let app = create_test_router(); + + // 模拟来自不可信代理的请求 + let request = axum::http::Request::builder() + .uri("/test") + .header("X-Forwarded-For", "203.0.113.195") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body_str = String::from_utf8(body.to_vec()).unwrap(); + + // 由于请求不是来自可信代理,X-Forwarded-For 应该被忽略 + // 应该显示直接连接的 IP + assert!(!body_str.contains("203.0.113.195")); + } + + #[tokio::test] + async fn test_proxy_chain_too_long() { + let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; + + let config = TrustedProxyConfig::new( + proxies, + ValidationMode::Strict, + true, + 3, // 最大 3 跳 + true, + vec![], + ); + + let proxy_layer = TrustedProxyLayer::enabled(config, None); + + let app = Router::new().route("/test", get(|| async { "OK" })).layer(proxy_layer); + + // 模拟超长代理链 + let xff_value = (0..5).map(|i| format!("10.0.{}.1", i)).collect::>().join(", "); + + let request = axum::http::Request::builder() + .uri("/test") + .header("X-Forwarded-For", xff_value) + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + + // 由于代理链太长,验证应该失败 + // 注意:中间件可能会降级处理,而不是直接拒绝 + assert_eq!(response.status(), 200); + } + + #[tokio::test] + async fn test_rfc7239_forwarded_header() { + let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())]; + + let config = TrustedProxyConfig::new( + proxies, + ValidationMode::HopByHop, + true, // 启用 RFC 7239 + 10, + true, + vec![], + ); + + let proxy_layer = TrustedProxyLayer::enabled(config, None); + + let app = Router::new() + .route( + "/test", + get(|req: Request| async move { + let client_info = req.extensions().get::().unwrap(); + format!("IP: {}", client_info.real_ip) + }), + ) + .layer(proxy_layer); + + // 模拟使用 RFC 7239 Forwarded 头部的请求 + let request = axum::http::Request::builder() + .uri("/test") + .header("Forwarded", r#"for=192.0.2.60;proto=https;by=203.0.113.43"#) + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + assert_eq!(response.status(), 200); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body_str = String::from_utf8(body.to_vec()).unwrap(); + + // 应该解析 RFC 7239 头部 + assert!(body_str.contains("192.0.2.60")); + } +} diff --git a/crates/trusted-proxies/test/unit/config_tests.rs b/crates/trusted-proxies/test/unit/config_tests.rs index 6238cfff..61ce51d7 100644 --- a/crates/trusted-proxies/test/unit/config_tests.rs +++ b/crates/trusted-proxies/test/unit/config_tests.rs @@ -11,3 +11,173 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Configuration module unit tests + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + use std::str::FromStr; + + use crate::config::env::{DEFAULT_TRUSTED_PROXIES, ENV_TRUSTED_PROXIES}; + use crate::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode}; + + #[test] + fn test_config_loader_default() { + // 清理环境变量 + std::env::remove_var(ENV_TRUSTED_PROXIES); + + let config = ConfigLoader::from_env_or_default(); + + // 验证默认值 + assert_eq!(config.server_addr.port(), 3000); + assert!(!config.proxy.proxies.is_empty()); + assert_eq!(config.proxy.validation_mode, ValidationMode::HopByHop); + assert!(config.proxy.enable_rfc7239); + assert_eq!(config.proxy.max_hops, 10); + } + + #[test] + fn test_config_loader_env_vars() { + // 设置环境变量 + std::env::set_var(ENV_TRUSTED_PROXIES, "192.168.1.0/24,10.0.0.0/8"); + std::env::set_var("TRUSTED_PROXY_VALIDATION_MODE", "strict"); + std::env::set_var("TRUSTED_PROXY_MAX_HOPS", "5"); + std::env::set_var("SERVER_PORT", "8080"); + + let config = ConfigLoader::from_env(); + + if let Ok(config) = config { + assert_eq!(config.server_addr.port(), 8080); + assert_eq!(config.proxy.validation_mode, ValidationMode::Strict); + assert_eq!(config.proxy.max_hops, 5); + + // 清理环境变量 + std::env::remove_var(ENV_TRUSTED_PROXIES); + std::env::remove_var("TRUSTED_PROXY_VALIDATION_MODE"); + std::env::remove_var("TRUSTED_PROXY_MAX_HOPS"); + std::env::remove_var("SERVER_PORT"); + } else { + panic!("Failed to load config from env"); + } + } + + #[test] + fn test_trusted_proxy_config() { + let proxies = vec![ + TrustedProxy::Single("192.168.1.1".parse().unwrap()), + TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), + ]; + + let config = TrustedProxyConfig::new(proxies.clone(), ValidationMode::Strict, true, 10, true, vec![]); + + assert_eq!(config.proxies.len(), 2); + assert_eq!(config.validation_mode, ValidationMode::Strict); + assert!(config.enable_rfc7239); + assert_eq!(config.max_hops, 10); + assert!(config.enable_chain_continuity_check); + + // 测试 IP 检查 + let test_ip: IpAddr = "192.168.1.1".parse().unwrap(); + let test_socket_addr = std::net::SocketAddr::new(test_ip, 8080); + + assert!(config.is_trusted(&test_socket_addr)); + + let test_ip2: IpAddr = "10.0.1.1".parse().unwrap(); + let test_socket_addr2 = std::net::SocketAddr::new(test_ip2, 8080); + + assert!(config.is_trusted(&test_socket_addr2)); + } + + #[test] + fn test_validation_mode_from_str() { + assert_eq!(ValidationMode::from_str("lenient").unwrap(), ValidationMode::Lenient); + assert_eq!(ValidationMode::from_str("strict").unwrap(), ValidationMode::Strict); + assert_eq!(ValidationMode::from_str("hop_by_hop").unwrap(), ValidationMode::HopByHop); + + // 测试无效值 + assert!(ValidationMode::from_str("invalid").is_err()); + } + + #[test] + fn test_trusted_proxy_contains() { + // 测试单个 IP + let single_proxy = TrustedProxy::Single("192.168.1.1".parse().unwrap()); + let test_ip: IpAddr = "192.168.1.1".parse().unwrap(); + let test_ip2: IpAddr = "192.168.1.2".parse().unwrap(); + + assert!(single_proxy.contains(&test_ip)); + assert!(!single_proxy.contains(&test_ip2)); + + // 测试 CIDR 范围 + let cidr_proxy = TrustedProxy::Cidr("192.168.1.0/24".parse().unwrap()); + + assert!(cidr_proxy.contains(&test_ip)); + assert!(cidr_proxy.contains(&test_ip2)); + + let test_ip3: IpAddr = "192.168.2.1".parse().unwrap(); + assert!(!cidr_proxy.contains(&test_ip3)); + } + + #[test] + fn test_private_network_check() { + let config = TrustedProxyConfig::new( + Vec::new(), + ValidationMode::Lenient, + true, + 10, + true, + vec!["10.0.0.0/8".parse().unwrap(), "192.168.0.0/16".parse().unwrap()], + ); + + let private_ip: IpAddr = "10.0.1.1".parse().unwrap(); + let private_ip2: IpAddr = "192.168.1.1".parse().unwrap(); + let public_ip: IpAddr = "8.8.8.8".parse().unwrap(); + + assert!(config.is_private_network(&private_ip)); + assert!(config.is_private_network(&private_ip2)); + assert!(!config.is_private_network(&public_ip)); + } + + #[test] + fn test_parse_ip_list_from_env() { + use crate::config::env::parse_ip_list_from_env; + + // 测试有效的 IP 列表 + std::env::set_var("TEST_IP_LIST", "10.0.0.0/8,192.168.1.0/24"); + + let result = parse_ip_list_from_env("TEST_IP_LIST", ""); + assert!(result.is_ok()); + + let networks = result.unwrap(); + assert_eq!(networks.len(), 2); + + // 测试空值 + std::env::set_var("TEST_IP_LIST_EMPTY", ""); + let result = parse_ip_list_from_env("TEST_IP_LIST_EMPTY", ""); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + + // 测试无效值 + std::env::set_var("TEST_IP_LIST_INVALID", "invalid,10.0.0.0/8"); + let result = parse_ip_list_from_env("TEST_IP_LIST_INVALID", ""); + assert!(result.is_ok()); // 无效项会被跳过 + + // 清理环境变量 + std::env::remove_var("TEST_IP_LIST"); + std::env::remove_var("TEST_IP_LIST_EMPTY"); + std::env::remove_var("TEST_IP_LIST_INVALID"); + } + + #[test] + fn test_default_values() { + use crate::config::env::{ + DEFAULT_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_MAX_HOPS, DEFAULT_PROXY_VALIDATION_MODE, DEFAULT_TRUSTED_PROXIES, + }; + + assert_eq!(DEFAULT_TRUSTED_PROXIES, "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8"); + assert_eq!(DEFAULT_PROXY_VALIDATION_MODE, "hop_by_hop"); + assert_eq!(DEFAULT_PROXY_MAX_HOPS, 10); + assert!(DEFAULT_PROXY_ENABLE_RFC7239); + } +} diff --git a/crates/trusted-proxies/test/unit/ip_tests.rs b/crates/trusted-proxies/test/unit/ip_tests.rs new file mode 100644 index 00000000..735fbd8a --- /dev/null +++ b/crates/trusted-proxies/test/unit/ip_tests.rs @@ -0,0 +1,242 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! IP utility tests + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + use std::str::FromStr; + + use crate::utils::ip::IpUtils; + + #[test] + fn test_is_valid_ip_address() { + // 测试有效 IP + let valid_ip: IpAddr = "192.168.1.1".parse().unwrap(); + assert!(IpUtils::is_valid_ip_address(&valid_ip)); + + // 测试未指定地址 + let unspecified_ip: IpAddr = "0.0.0.0".parse().unwrap(); + assert!(!IpUtils::is_valid_ip_address(&unspecified_ip)); + + // 测试多播地址 + let multicast_ip: IpAddr = "224.0.0.1".parse().unwrap(); + assert!(!IpUtils::is_valid_ip_address(&multicast_ip)); + + // 测试 IPv6 + let valid_ipv6: IpAddr = "2001:db8::1".parse().unwrap(); + assert!(IpUtils::is_valid_ip_address(&valid_ipv6)); + + let unspecified_ipv6: IpAddr = "::".parse().unwrap(); + assert!(!IpUtils::is_valid_ip_address(&unspecified_ipv6)); + } + + #[test] + fn test_is_reserved_ip() { + // 测试私有地址 + let private_ip: IpAddr = "10.0.0.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&private_ip)); + + // 测试回环地址 + let loopback_ip: IpAddr = "127.0.0.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&loopback_ip)); + + // 测试链路本地地址 + let link_local_ip: IpAddr = "169.254.0.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&link_local_ip)); + + // 测试文档地址 + let documentation_ip: IpAddr = "192.0.2.1".parse().unwrap(); + assert!(IpUtils::is_reserved_ip(&documentation_ip)); + + // 测试公网地址 + let public_ip: IpAddr = "8.8.8.8".parse().unwrap(); + assert!(!IpUtils::is_reserved_ip(&public_ip)); + } + + #[test] + fn test_is_private_ip() { + // 测试 10.0.0.0/8 + assert!(IpUtils::is_private_ip(&"10.0.0.1".parse().unwrap())); + assert!(IpUtils::is_private_ip(&"10.255.255.254".parse().unwrap())); + + // 测试 172.16.0.0/12 + assert!(IpUtils::is_private_ip(&"172.16.0.1".parse().unwrap())); + assert!(IpUtils::is_private_ip(&"172.31.255.254".parse().unwrap())); + assert!(!IpUtils::is_private_ip(&"172.15.0.1".parse().unwrap())); + assert!(!IpUtils::is_private_ip(&"172.32.0.1".parse().unwrap())); + + // 测试 192.168.0.0/16 + assert!(IpUtils::is_private_ip(&"192.168.0.1".parse().unwrap())); + assert!(IpUtils::is_private_ip(&"192.168.255.254".parse().unwrap())); + + // 测试公网地址 + assert!(!IpUtils::is_private_ip(&"8.8.8.8".parse().unwrap())); + assert!(!IpUtils::is_private_ip(&"203.0.113.1".parse().unwrap())); + } + + #[test] + fn test_is_loopback_ip() { + // IPv4 回环地址 + assert!(IpUtils::is_loopback_ip(&"127.0.0.1".parse().unwrap())); + assert!(IpUtils::is_loopback_ip(&"127.255.255.254".parse().unwrap())); + + // IPv6 回环地址 + assert!(IpUtils::is_loopback_ip(&"::1".parse().unwrap())); + + // 非回环地址 + assert!(!IpUtils::is_loopback_ip(&"192.168.1.1".parse().unwrap())); + assert!(!IpUtils::is_loopback_ip(&"2001:db8::1".parse().unwrap())); + } + + #[test] + fn test_is_link_local_ip() { + // IPv4 链路本地地址 + assert!(IpUtils::is_link_local_ip(&"169.254.0.1".parse().unwrap())); + assert!(IpUtils::is_link_local_ip(&"169.254.255.254".parse().unwrap())); + + // IPv6 链路本地地址 + assert!(IpUtils::is_link_local_ip(&"fe80::1".parse().unwrap())); + assert!(IpUtils::is_link_local_ip(&"fe80::abcd:1234:5678:9abc".parse().unwrap())); + + // 非链路本地地址 + assert!(!IpUtils::is_link_local_ip(&"192.168.1.1".parse().unwrap())); + assert!(!IpUtils::is_link_local_ip(&"2001:db8::1".parse().unwrap())); + } + + #[test] + fn test_is_documentation_ip() { + // IPv4 文档地址 + assert!(IpUtils::is_documentation_ip(&"192.0.2.1".parse().unwrap())); + assert!(IpUtils::is_documentation_ip(&"198.51.100.1".parse().unwrap())); + assert!(IpUtils::is_documentation_ip(&"203.0.113.1".parse().unwrap())); + + // IPv6 文档地址 + assert!(IpUtils::is_documentation_ip(&"2001:db8::1".parse().unwrap())); + + // 非文档地址 + assert!(!IpUtils::is_documentation_ip(&"8.8.8.8".parse().unwrap())); + assert!(!IpUtils::is_documentation_ip(&"2001:4860::1".parse().unwrap())); + } + + #[test] + fn test_parse_ip_or_cidr() { + // 测试单个 IP + let result = IpUtils::parse_ip_or_cidr("192.168.1.1"); + assert!(result.is_ok()); + + // 测试 CIDR + let result = IpUtils::parse_ip_or_cidr("192.168.1.0/24"); + assert!(result.is_ok()); + + // 测试 IPv6 + let result = IpUtils::parse_ip_or_cidr("2001:db8::1"); + assert!(result.is_ok()); + + let result = IpUtils::parse_ip_or_cidr("2001:db8::/32"); + assert!(result.is_ok()); + + // 测试无效输入 + let result = IpUtils::parse_ip_or_cidr("invalid"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_ip_list() { + // 测试有效的 IP 列表 + let result = IpUtils::parse_ip_list("192.168.1.1, 10.0.0.1, 8.8.8.8"); + assert!(result.is_ok()); + + let ips = result.unwrap(); + assert_eq!(ips.len(), 3); + assert_eq!(ips[0], IpAddr::from_str("192.168.1.1").unwrap()); + assert_eq!(ips[1], IpAddr::from_str("10.0.0.1").unwrap()); + assert_eq!(ips[2], IpAddr::from_str("8.8.8.8").unwrap()); + + // 测试带空格的 IP 列表 + let result = IpUtils::parse_ip_list("192.168.1.1,10.0.0.1"); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 2); + + // 测试空列表 + let result = IpUtils::parse_ip_list(""); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + + // 测试无效 IP + let result = IpUtils::parse_ip_list("192.168.1.1, invalid"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_network_list() { + // 测试有效的网络列表 + let result = IpUtils::parse_network_list("192.168.1.0/24, 10.0.0.0/8"); + assert!(result.is_ok()); + + let networks = result.unwrap(); + assert_eq!(networks.len(), 2); + + // 测试单个 IP(会被当作/32 或/128 网络) + let result = IpUtils::parse_network_list("192.168.1.1"); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 1); + + // 测试无效网络 + let result = IpUtils::parse_network_list("192.168.1.0/24, invalid"); + assert!(result.is_err()); + } + + #[test] + fn test_ip_in_networks() { + let networks = vec!["10.0.0.0/8".parse().unwrap(), "192.168.1.0/24".parse().unwrap()]; + + let ip_in_network: IpAddr = "10.0.1.1".parse().unwrap(); + let ip_in_network2: IpAddr = "192.168.1.100".parse().unwrap(); + let ip_not_in_network: IpAddr = "8.8.8.8".parse().unwrap(); + + assert!(IpUtils::ip_in_networks(&ip_in_network, &networks)); + assert!(IpUtils::ip_in_networks(&ip_in_network2, &networks)); + assert!(!IpUtils::ip_in_networks(&ip_not_in_network, &networks)); + } + + #[test] + fn test_get_ip_type() { + assert_eq!(IpUtils::get_ip_type(&"10.0.0.1".parse().unwrap()), "private"); + assert_eq!(IpUtils::get_ip_type(&"127.0.0.1".parse().unwrap()), "loopback"); + assert_eq!(IpUtils::get_ip_type(&"169.254.0.1".parse().unwrap()), "link_local"); + assert_eq!(IpUtils::get_ip_type(&"192.0.2.1".parse().unwrap()), "documentation"); + assert_eq!(IpUtils::get_ip_type(&"224.0.0.1".parse().unwrap()), "reserved"); + assert_eq!(IpUtils::get_ip_type(&"8.8.8.8".parse().unwrap()), "public"); + } + + #[test] + fn test_canonical_ip() { + // 测试 IPv4 + let ipv4: IpAddr = "192.168.001.001".parse().unwrap(); + assert_eq!(IpUtils::canonical_ip(&ipv4), "192.168.1.1"); + + // 测试 IPv6 压缩 + let ipv6_full: IpAddr = "2001:0db8:0000:0000:0000:0000:0000:0001".parse().unwrap(); + let ipv6_compressed: IpAddr = "2001:db8::1".parse().unwrap(); + + assert_eq!(IpUtils::canonical_ip(&ipv6_full), "2001:db8::1"); + assert_eq!(IpUtils::canonical_ip(&ipv6_compressed), "2001:db8::1"); + + // 测试包含多个零段的 IPv6 + let ipv6_multi_zero: IpAddr = "2001:0db8:0000:0000:abcd:0000:0000:1234".parse().unwrap(); + assert_eq!(IpUtils::canonical_ip(&ipv6_multi_zero), "2001:db8::abcd:0:0:1234"); + } +} diff --git a/crates/trusted-proxies/test/unit/mod.rs b/crates/trusted-proxies/test/unit/mod.rs index 6238cfff..92a5b076 100644 --- a/crates/trusted-proxies/test/unit/mod.rs +++ b/crates/trusted-proxies/test/unit/mod.rs @@ -11,3 +11,16 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Unit tests for the trusted proxy system + +mod config_tests; +mod ip_tests; +mod validation_tests; +mod validator_tests; + +// 重新导出测试模块 +pub use config_tests::*; +pub use ip_tests::*; +pub use validation_tests::*; +pub use validator_tests::*; diff --git a/crates/trusted-proxies/test/unit/validation_tests.rs b/crates/trusted-proxies/test/unit/validation_tests.rs new file mode 100644 index 00000000..44954b93 --- /dev/null +++ b/crates/trusted-proxies/test/unit/validation_tests.rs @@ -0,0 +1,637 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Validation utility unit tests + +#[cfg(test)] +mod tests { + use http::HeaderMap; + use std::net::IpAddr; + + use crate::utils::ip::IpUtils; + use crate::utils::validation::ValidationUtils; + + /// 测试电子邮件验证 + #[test] + fn test_email_validation() { + // 有效的电子邮件地址 + assert!(ValidationUtils::is_valid_email("user@example.com")); + assert!(ValidationUtils::is_valid_email("first.last@example.co.uk")); + assert!(ValidationUtils::is_valid_email("user123@example.org")); + assert!(ValidationUtils::is_valid_email("user+tag@example.com")); + assert!(ValidationUtils::is_valid_email("user_name@example-domain.com")); + + // 无效的电子邮件地址 + assert!(!ValidationUtils::is_valid_email("")); + assert!(!ValidationUtils::is_valid_email("invalid-email")); + assert!(!ValidationUtils::is_valid_email("user@")); + assert!(!ValidationUtils::is_valid_email("@example.com")); + assert!(!ValidationUtils::is_valid_email("user@.com")); + assert!(!ValidationUtils::is_valid_email("user@example.")); + assert!(!ValidationUtils::is_valid_email("user@example..com")); + assert!(!ValidationUtils::is_valid_email("user@example_com")); + assert!(!ValidationUtils::is_valid_email("user@[127.0.0.1]")); + assert!(!ValidationUtils::is_valid_email("user name@example.com")); + assert!(!ValidationUtils::is_valid_email("user@exa mple.com")); + assert!(!ValidationUtils::is_valid_email("user@example.c")); + } + + /// 测试 URL 验证 + #[test] + fn test_url_validation() { + // 有效的 URL + assert!(ValidationUtils::is_valid_url("https://example.com")); + assert!(ValidationUtils::is_valid_url("http://example.com")); + assert!(ValidationUtils::is_valid_url("example.com")); + assert!(ValidationUtils::is_valid_url("sub.example.com")); + assert!(ValidationUtils::is_valid_url("example.co.uk")); + assert!(ValidationUtils::is_valid_url("example.com/path")); + assert!(ValidationUtils::is_valid_url("example.com/path/to/resource")); + assert!(ValidationUtils::is_valid_url("example.com/?query=param")); + assert!(ValidationUtils::is_valid_url("sub-domain.example-domain.com")); + + // 无效的 URL + assert!(!ValidationUtils::is_valid_url("")); + assert!(!ValidationUtils::is_valid_url("invalid")); + assert!(!ValidationUtils::is_valid_url("example")); + assert!(!ValidationUtils::is_valid_url("example.")); + assert!(!ValidationUtils::is_valid_url(".com")); + assert!(!ValidationUtils::is_valid_url("http://")); + assert!(!ValidationUtils::is_valid_url("https://")); + assert!(!ValidationUtils::is_valid_url("://example.com")); + assert!(!ValidationUtils::is_valid_url("example..com")); + assert!(!ValidationUtils::is_valid_url("-example.com")); + assert!(!ValidationUtils::is_valid_url("example-.com")); + assert!(!ValidationUtils::is_valid_url("example_com")); + } + + /// 测试 X-Forwarded-For 头部验证 + #[test] + fn test_x_forwarded_for_validation() { + // 有效的 X-Forwarded-For 头部 + assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195")); + assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195, 198.51.100.1")); + assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195,198.51.100.1,10.0.1.100")); + assert!(ValidationUtils::validate_x_forwarded_for("2001:db8::1")); + assert!(ValidationUtils::validate_x_forwarded_for("2001:db8::1, 2001:db8::2")); + assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195:8080, 198.51.100.1:443")); // 带端口 + + // 无效的 X-Forwarded-For 头部 + assert!(!ValidationUtils::validate_x_forwarded_for("")); // 空字符串 + assert!(!ValidationUtils::validate_x_forwarded_for(" ")); // 只有空格 + assert!(!ValidationUtils::validate_x_forwarded_for("invalid")); // 无效 IP + assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195, invalid")); // 部分无效 + assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195, ")); // 尾部逗号加空格 + assert!(!ValidationUtils::validate_x_forwarded_for(",203.0.113.195")); // 开头逗号 + assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195,,198.51.100.1")); // 连续逗号 + assert!(!ValidationUtils::validate_x_forwarded_for("256.256.256.256")); // 超出范围的 IP + } + + /// 测试 Forwarded 头部验证 (RFC 7239) + #[test] + fn test_forwarded_header_validation() { + // 有效的 Forwarded 头部 + assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60")); + assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60;proto=http")); + assert!(ValidationUtils::validate_forwarded_header("for=\"[2001:db8:cafe::17]\";proto=https")); + assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.43, for=198.51.100.17")); + assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60;proto=http;by=203.0.113.43")); + assert!(ValidationUtils::validate_forwarded_header( + "by=203.0.113.43;for=192.0.2.60;host=example.com;proto=https" + )); + + // 无效的 Forwarded 头部 + assert!(!ValidationUtils::validate_forwarded_header("")); // 空字符串 + assert!(!ValidationUtils::validate_forwarded_header(" ")); // 只有空格 + assert!(!ValidationUtils::validate_forwarded_header("invalid")); // 无效格式 + assert!(!ValidationUtils::validate_forwarded_header("for=192.0.2.60 proto=http")); // 缺少分号 + assert!(!ValidationUtils::validate_forwarded_header("for;192.0.2.60")); // 缺少等号 + assert!(!ValidationUtils::validate_forwarded_header("=192.0.2.60")); // 缺少键 + assert!(!ValidationUtils::validate_forwarded_header("for=")); // 缺少值 + } + + /// 测试 IP 范围验证 + #[test] + fn test_ip_in_range_validation() { + let cidr_ranges = vec![ + "10.0.0.0/8".to_string(), + "192.168.0.0/16".to_string(), + "172.16.0.0/12".to_string(), + "2001:db8::/32".to_string(), + ]; + + // IP 在范围内 + let ip_in_range: IpAddr = "10.0.1.1".parse().unwrap(); + assert!(ValidationUtils::validate_ip_in_range(&ip_in_range, &cidr_ranges)); + + let ip_in_range2: IpAddr = "192.168.1.100".parse().unwrap(); + assert!(ValidationUtils::validate_ip_in_range(&ip_in_range2, &cidr_ranges)); + + let ip_in_range3: IpAddr = "172.16.0.1".parse().unwrap(); + assert!(ValidationUtils::validate_ip_in_range(&ip_in_range3, &cidr_ranges)); + + let ipv6_in_range: IpAddr = "2001:db8::1".parse().unwrap(); + assert!(ValidationUtils::validate_ip_in_range(&ipv6_in_range, &cidr_ranges)); + + // IP 不在范围内 + let ip_not_in_range: IpAddr = "8.8.8.8".parse().unwrap(); + assert!(!ValidationUtils::validate_ip_in_range(&ip_not_in_range, &cidr_ranges)); + + let ip_not_in_range2: IpAddr = "203.0.113.1".parse().unwrap(); + assert!(!ValidationUtils::validate_ip_in_range(&ip_not_in_range2, &cidr_ranges)); + + let ipv6_not_in_range: IpAddr = "2001:4860::1".parse().unwrap(); + assert!(!ValidationUtils::validate_ip_in_range(&ipv6_not_in_range, &cidr_ranges)); + + // 空范围列表 + assert!(!ValidationUtils::validate_ip_in_range(&ip_in_range, &Vec::new())); + + // 无效的 CIDR 范围(应该被忽略) + let invalid_ranges = vec![ + "invalid".to_string(), + "10.0.0.0/8".to_string(), // 这个有效 + ]; + + let test_ip: IpAddr = "10.0.1.1".parse().unwrap(); + // 即使有无效范围,只要有一个有效范围包含 IP,就应该返回 true + assert!(ValidationUtils::validate_ip_in_range(&test_ip, &invalid_ranges)); + } + + /// 测试头部值验证 + #[test] + fn test_header_value_validation() { + // 有效的头部值 + assert!(ValidationUtils::validate_header_value("text/plain")); + assert!(ValidationUtils::validate_header_value("application/json; charset=utf-8")); + assert!(ValidationUtils::validate_header_value("Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9")); + assert!(ValidationUtils::validate_header_value("127.0.0.1:8080")); + assert!(ValidationUtils::validate_header_value("")); // 空字符串是有效的 + assert!(ValidationUtils::validate_header_value("normal text with spaces")); // 普通文本 + assert!(ValidationUtils::validate_header_value("value\twith\ttabs")); // 包含制表符 + + // 长但有效的头部值(边界情况) + let long_value = "a".repeat(8192); + assert!(ValidationUtils::validate_header_value(&long_value)); + + // 无效的头部值 + let too_long_value = "a".repeat(8193); // 超过长度限制 + assert!(!ValidationUtils::validate_header_value(&too_long_value)); + + // 包含控制字符(除了制表符、换行符、回车符) + assert!(!ValidationUtils::validate_header_value("value\x00with_null")); // 空字符 + assert!(!ValidationUtils::validate_header_value("value\x01with_soh")); // 标题开始 + assert!(!ValidationUtils::validate_header_value("value\x1fwith_us")); // 单元分隔符 + + // 换行符和回车符是允许的(在某些上下文中) + assert!(ValidationUtils::validate_header_value("line1\nline2")); + assert!(ValidationUtils::validate_header_value("line1\r\nline2")); + } + + /// 测试头部映射验证 + #[test] + fn test_headers_validation() { + let mut valid_headers = HeaderMap::new(); + valid_headers.insert("Content-Type", "application/json".parse().unwrap()); + valid_headers.insert("Authorization", "Bearer token123".parse().unwrap()); + valid_headers.insert("X-Forwarded-For", "203.0.113.195".parse().unwrap()); + + assert!(ValidationUtils::validate_headers(&valid_headers)); + + // 测试头部名称过长 + let mut invalid_headers = HeaderMap::new(); + let long_name = "X-".to_string() + &"A".repeat(300); // 超过 256 字符 + invalid_headers.insert(long_name, "value".parse().unwrap()); + + assert!(!ValidationUtils::validate_headers(&invalid_headers)); + + // 测试头部值过长 + let mut invalid_headers2 = HeaderMap::new(); + let long_value = "A".repeat(8193); // 超过 8192 字节 + invalid_headers2.insert("X-Custom-Header", long_value.parse().unwrap()); + + assert!(!ValidationUtils::validate_headers(&invalid_headers2)); + + // 测试二进制数据(无法转换为字符串) + let mut binary_headers = HeaderMap::new(); + let binary_data = vec![0x00, 0x01, 0x02, 0x03]; + binary_headers.insert("X-Binary-Data", http::HeaderValue::from_bytes(&binary_data).unwrap()); + + // 二进制数据应该通过验证(只要长度不超过限制) + assert!(ValidationUtils::validate_headers(&binary_headers)); + + // 测试过长的二进制数据 + let mut long_binary_headers = HeaderMap::new(); + let long_binary_data = vec![0x00; 8193]; // 超过 8192 字节 + long_binary_headers.insert("X-Long-Binary", http::HeaderValue::from_bytes(&long_binary_data).unwrap()); + + assert!(!ValidationUtils::validate_headers(&long_binary_headers)); + } + + /// 测试端口号验证 + #[test] + fn test_port_validation() { + // 有效端口号 + assert!(ValidationUtils::validate_port(1)); + assert!(ValidationUtils::validate_port(80)); + assert!(ValidationUtils::validate_port(443)); + assert!(ValidationUtils::validate_port(8080)); + assert!(ValidationUtils::validate_port(65535)); + + // 无效端口号 + assert!(!ValidationUtils::validate_port(0)); // 端口 0 是保留的 + assert!(!ValidationUtils::validate_port(65536)); // 超过最大值 + assert!(!ValidationUtils::validate_port(70000)); // 远超过最大值 + } + + /// 测试 CIDR 表示法验证 + #[test] + fn test_cidr_validation() { + // 有效的 CIDR 表示法 + assert!(ValidationUtils::validate_cidr("192.168.1.0/24")); + assert!(ValidationUtils::validate_cidr("10.0.0.0/8")); + assert!(ValidationUtils::validate_cidr("0.0.0.0/0")); // 默认路由 + assert!(ValidationUtils::validate_cidr("2001:db8::/32")); + assert!(ValidationUtils::validate_cidr("::/0")); // IPv6 默认路由 + assert!(ValidationUtils::validate_cidr("192.168.1.1/32")); // 单个主机 + assert!(ValidationUtils::validate_cidr("2001:db8::1/128")); // 单个 IPv6 主机 + + // 无效的 CIDR 表示法 + assert!(!ValidationUtils::validate_cidr("")); // 空字符串 + assert!(!ValidationUtils::validate_cidr("invalid")); // 无效格式 + assert!(!ValidationUtils::validate_cidr("192.168.1.0")); // 缺少前缀长度 + assert!(!ValidationUtils::validate_cidr("192.168.1.0/33")); // 前缀长度过大 + assert!(!ValidationUtils::validate_cidr("256.256.256.256/24")); // 无效 IP + assert!(!ValidationUtils::validate_cidr("192.168.1.0/24/extra")); // 多余的部分 + assert!(!ValidationUtils::validate_cidr("192.168.1.0/-1")); // 负的前缀长度 + assert!(!ValidationUtils::validate_cidr("192.168.1.0/abc")); // 非数字前缀长度 + } + + /// 测试代理链长度验证 + #[test] + fn test_proxy_chain_length_validation() { + let chain = vec![ + "203.0.113.195".parse().unwrap(), + "198.51.100.1".parse().unwrap(), + "10.0.1.100".parse().unwrap(), + ]; + + // 链长度在限制内 + assert!(ValidationUtils::validate_proxy_chain_length(&chain, 3)); + assert!(ValidationUtils::validate_proxy_chain_length(&chain, 5)); + assert!(ValidationUtils::validate_proxy_chain_length(&chain, 10)); + + // 链长度超过限制 + assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 2)); + assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 1)); + assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 0)); + + // 空链 + let empty_chain: Vec = Vec::new(); + assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 0)); + assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 1)); + assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 10)); + } + + /// 测试代理链连续性验证 + #[test] + fn test_proxy_chain_continuity_validation() { + // 连续链(无重复相邻 IP) + let continuous_chain = vec![ + "203.0.113.195".parse().unwrap(), + "198.51.100.1".parse().unwrap(), + "10.0.1.100".parse().unwrap(), + ]; + assert!(ValidationUtils::validate_proxy_chain_continuity(&continuous_chain)); + + // 不连续链(有重复相邻 IP) + let discontinuous_chain = vec![ + "203.0.113.195".parse().unwrap(), + "198.51.100.1".parse().unwrap(), + "198.51.100.1".parse().unwrap(), // 重复 + "10.0.1.100".parse().unwrap(), + ]; + assert!(!ValidationUtils::validate_proxy_chain_continuity(&discontinuous_chain)); + + // 短链(应该总是连续的) + let short_chain = vec!["203.0.113.195".parse().unwrap()]; + assert!(ValidationUtils::validate_proxy_chain_continuity(&short_chain)); + + let two_item_chain = vec!["203.0.113.195".parse().unwrap(), "198.51.100.1".parse().unwrap()]; + assert!(ValidationUtils::validate_proxy_chain_continuity(&two_item_chain)); + + // 空链(应该总是连续的) + let empty_chain: Vec = Vec::new(); + assert!(ValidationUtils::validate_proxy_chain_continuity(&empty_chain)); + + // 有多个重复的链 + let multi_duplicate_chain = vec![ + "203.0.113.195".parse().unwrap(), + "203.0.113.195".parse().unwrap(), // 重复 1 + "198.51.100.1".parse().unwrap(), + "198.51.100.1".parse().unwrap(), // 重复 2 + ]; + assert!(!ValidationUtils::validate_proxy_chain_continuity(&multi_duplicate_chain)); + } + + /// 测试安全字符串验证 + #[test] + fn test_safe_string_validation() { + // 安全字符串 + assert!(ValidationUtils::is_safe_string("example")); + assert!(ValidationUtils::is_safe_string("example123")); + assert!(ValidationUtils::is_safe_string("example-test")); + assert!(ValidationUtils::is_safe_string("example.test")); + assert!(ValidationUtils::is_safe_string("example~test")); + assert!(ValidationUtils::is_safe_string("http://example.com/path")); + assert!(ValidationUtils::is_safe_string("https://example.com/?query=param")); + assert!(ValidationUtils::is_safe_string("user@example.com")); + assert!(ValidationUtils::is_safe_string("192.168.1.1:8080")); + assert!(ValidationUtils::is_safe_string("[2001:db8::1]:8080")); + + // 不安全字符串 + assert!(!ValidationUtils::is_safe_string("")); // 空字符串 + assert!(!ValidationUtils::is_safe_string("example test")); // 包含空格 + assert!(!ValidationUtils::is_safe_string("example\ttest")); // 包含制表符 + assert!(!ValidationUtils::is_safe_string("example\ntest")); // 包含换行符 + assert!(!ValidationUtils::is_safe_string("example")); // 包含尖括号 + assert!(!ValidationUtils::is_safe_string("example\"test")); // 包含双引号 + assert!(!ValidationUtils::is_safe_string("example'test")); // 包含单引号 + assert!(!ValidationUtils::is_safe_string("example\\test")); // 包含反斜杠 + assert!(!ValidationUtils::is_safe_string("example`test")); // 包含反引号 + assert!(!ValidationUtils::is_safe_string("example|test")); // 包含竖线 + assert!(!ValidationUtils::is_safe_string("example$test")); // 包含美元符号 + assert!(!ValidationUtils::is_safe_string("example%test")); // 包含百分号 + assert!(!ValidationUtils::is_safe_string("example^test")); // 包含脱字符 + assert!(!ValidationUtils::is_safe_string("example&test")); // 包含和号 + assert!(!ValidationUtils::is_safe_string("example(test")); // 包含括号 + assert!(!ValidationUtils::is_safe_string("example)test")); // 包含括号 + assert!(!ValidationUtils::is_safe_string("example[test")); // 包含方括号 + assert!(!ValidationUtils::is_safe_string("example]test")); // 包含方括号 + assert!(!ValidationUtils::is_safe_string("example{test")); // 包含花括号 + assert!(!ValidationUtils::is_safe_string("example}test")); // 包含花括号 + } + + /// 测试速率限制参数验证 + #[test] + fn test_rate_limit_params_validation() { + // 有效的速率限制参数 + assert!(ValidationUtils::validate_rate_limit_params(1, 1)); // 最小值 + assert!(ValidationUtils::validate_rate_limit_params(100, 60)); // 典型值 + assert!(ValidationUtils::validate_rate_limit_params(10000, 86400)); // 最大值 + + // 无效的速率限制参数 + assert!(!ValidationUtils::validate_rate_limit_params(0, 60)); // 请求数为 0 + assert!(!ValidationUtils::validate_rate_limit_params(10001, 60)); // 请求数超过最大值 + assert!(!ValidationUtils::validate_rate_limit_params(100, 0)); // 周期为 0 + assert!(!ValidationUtils::validate_rate_limit_params(100, 86401)); // 周期超过最大值 + assert!(!ValidationUtils::validate_rate_limit_params(0, 0)); // 两者都为 0 + assert!(!ValidationUtils::validate_rate_limit_params(100001, 100000)); // 两者都超过最大值 + } + + /// 测试缓存参数验证 + #[test] + fn test_cache_params_validation() { + // 有效的缓存参数 + assert!(ValidationUtils::validate_cache_params(1, 1)); // 最小值 + assert!(ValidationUtils::validate_cache_params(10000, 300)); // 典型值 + assert!(ValidationUtils::validate_cache_params(1000000, 86400)); // 最大值 + + // 无效的缓存参数 + assert!(!ValidationUtils::validate_cache_params(0, 300)); // 容量为 0 + assert!(!ValidationUtils::validate_cache_params(1000001, 300)); // 容量超过最大值 + assert!(!ValidationUtils::validate_cache_params(10000, 0)); // TTL 为 0 + assert!(!ValidationUtils::validate_cache_params(10000, 86401)); // TTL 超过最大值 + assert!(!ValidationUtils::validate_cache_params(0, 0)); // 两者都为 0 + assert!(!ValidationUtils::validate_cache_params(2000000, 100000)); // 两者都超过最大值 + } + + /// 测试敏感数据脱敏 + #[test] + fn test_sensitive_data_masking() { + let sensitive_patterns = vec!["password", "token", "secret", "authorization", "api_key"]; + + // 测试各种敏感字段的脱敏 + let test_cases = vec![ + ( + r#"{"username":"john","password":"secret123"}"#, + r#"{"username":"john","password:[REDACTED]"}"#, + ), + (r#"token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9&user=john"#, r#"token:[REDACTED]&user=john"#), + ( + r#"Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"#, + r#"Authorization:[REDACTED]"#, + ), + (r#"api_key=sk_test_1234567890abcdef"#, r#"api_key:[REDACTED]"#), + (r#"secret_key=abc123&public_key=xyz789"#, r#"secret_key:[REDACTED]&public_key=xyz789"#), + ( + r#"password=123&password_confirmation=123"#, + r#"password:[REDACTED]&password_confirmation:[REDACTED]"#, + ), + ]; + + for (input, expected) in test_cases { + let result = ValidationUtils::mask_sensitive_data(input, &sensitive_patterns); + assert_eq!(result, expected, "Failed to mask: {}", input); + } + + // 测试不包含敏感数据的情况 + let safe_data = r#"{"name":"John","age":30,"city":"New York"}"#; + let result = ValidationUtils::mask_sensitive_data(safe_data, &sensitive_patterns); + assert_eq!(result, safe_data); + + // 测试空模式列表 + let sensitive_data = r#"password=secret123"#; + let result = ValidationUtils::mask_sensitive_data(sensitive_data, &Vec::new()); + assert_eq!(result, sensitive_data); + + // 测试空输入 + let result = ValidationUtils::mask_sensitive_data("", &sensitive_patterns); + assert_eq!(result, ""); + } + + /// 测试组合验证场景 + #[test] + fn test_combined_validation_scenarios() { + // 场景 1:完整的代理请求验证 + let proxy_chain = vec![ + "203.0.113.195".parse().unwrap(), + "198.51.100.1".parse().unwrap(), + "10.0.1.100".parse().unwrap(), + ]; + + assert!(ValidationUtils::validate_proxy_chain_length(&proxy_chain, 10)); + assert!(ValidationUtils::validate_proxy_chain_continuity(&proxy_chain)); + + // 场景 2:包含无效数据的头部验证 + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + headers.insert("X-Forwarded-For", "203.0.113.195, invalid, 10.0.1.100".parse().unwrap()); + + // 头部映射本身是有效的(即使包含无效的 X-Forwarded-For) + assert!(ValidationUtils::validate_headers(&headers)); + + // 但 X-Forwarded-For 内容无效 + let xff_value = headers.get("X-Forwarded-For").unwrap().to_str().unwrap(); + assert!(!ValidationUtils::validate_x_forwarded_for(xff_value)); + + // 场景 3:配置参数验证组合 + let cache_capacity = 10000; + let cache_ttl = 300; + let rate_limit_requests = 100; + let rate_limit_period = 60; + + assert!(ValidationUtils::validate_cache_params(cache_capacity, cache_ttl)); + assert!(ValidationUtils::validate_rate_limit_params(rate_limit_requests, rate_limit_period)); + + // 场景 4:IP 和 CIDR 验证组合 + let ip: IpAddr = "10.0.1.1".parse().unwrap(); + let cidr = "10.0.0.0/8"; + + assert!(IpUtils::is_private_ip(&ip)); + assert!(ValidationUtils::validate_cidr(cidr)); + assert!(ValidationUtils::validate_ip_in_range(&ip, &[cidr.to_string()])); + } + + /// 测试边缘情况和边界值 + #[test] + fn test_edge_cases_and_boundaries() { + // 测试头部值的边界长度 + let max_length_value = "a".repeat(8192); + let over_length_value = "a".repeat(8193); + + assert!(ValidationUtils::validate_header_value(&max_length_value)); + assert!(!ValidationUtils::validate_header_value(&over_length_value)); + + // 测试端口边界值 + assert!(!ValidationUtils::validate_port(0)); // 最小无效值 + assert!(ValidationUtils::validate_port(1)); // 最小有效值 + assert!(ValidationUtils::validate_port(65535)); // 最大有效值 + assert!(!ValidationUtils::validate_port(65536)); // 超过最大值 + + // 测试 CIDR 前缀长度边界值 + assert!(ValidationUtils::validate_cidr("192.168.1.0/0")); // 最小有效前缀 + assert!(ValidationUtils::validate_cidr("192.168.1.0/32")); // 最大有效前缀 + assert!(!ValidationUtils::validate_cidr("192.168.1.0/33")); // 超过最大值 + + // IPv6 CIDR 前缀长度 + assert!(ValidationUtils::validate_cidr("2001:db8::/0")); // 最小有效前缀 + assert!(ValidationUtils::validate_cidr("2001:db8::/128")); // 最大有效前缀 + + // 测试代理链边界情况 + let empty_chain: Vec = Vec::new(); + assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 0)); + assert!(ValidationUtils::validate_proxy_chain_continuity(&empty_chain)); + + // 测试单个 IP 的链 + let single_ip_chain = vec!["192.168.1.1".parse().unwrap()]; + assert!(ValidationUtils::validate_proxy_chain_length(&single_ip_chain, 1)); + assert!(ValidationUtils::validate_proxy_chain_continuity(&single_ip_chain)); + + // 测试速率限制边界值 + assert!(!ValidationUtils::validate_rate_limit_params(0, 60)); + assert!(ValidationUtils::validate_rate_limit_params(1, 1)); + assert!(ValidationUtils::validate_rate_limit_params(10000, 86400)); + assert!(!ValidationUtils::validate_rate_limit_params(10001, 86400)); + assert!(!ValidationUtils::validate_rate_limit_params(10000, 86401)); + + // 测试缓存参数边界值 + assert!(!ValidationUtils::validate_cache_params(0, 300)); + assert!(ValidationUtils::validate_cache_params(1, 1)); + assert!(ValidationUtils::validate_cache_params(1000000, 86400)); + assert!(!ValidationUtils::validate_cache_params(1000001, 86400)); + assert!(!ValidationUtils::validate_cache_params(1000000, 86401)); + } + + /// 测试性能敏感场景 + #[test] + fn test_performance_sensitive_scenarios() { + // 测试长代理链的处理 + let mut long_chain = Vec::new(); + for i in 0..100 { + let ip = format!("10.0.{}.1", i % 256).parse().unwrap(); + long_chain.push(ip); + } + + // 应该能快速处理长链 + assert!(ValidationUtils::validate_proxy_chain_length(&long_chain, 100)); + assert!(ValidationUtils::validate_proxy_chain_continuity(&long_chain)); + + // 测试大量 CIDR 范围的验证 + let mut cidr_ranges = Vec::new(); + for i in 0..1000 { + let cidr = format!("10.{}.0.0/16", i % 256); + cidr_ranges.push(cidr); + } + + let test_ip: IpAddr = "10.128.1.1".parse().unwrap(); + // 应该能快速在大范围列表中查找 + let start = std::time::Instant::now(); + let result = ValidationUtils::validate_ip_in_range(&test_ip, &cidr_ranges); + let duration = start.elapsed(); + + assert!(result); + // 验证时间应该在合理范围内(比如小于 10 毫秒) + assert!(duration < std::time::Duration::from_millis(10)); + + // 测试头部值验证的性能 + let large_header_value = "x".repeat(10000); // 超过 8192,应该快速拒绝 + let start = std::time::Instant::now(); + let result = ValidationUtils::validate_header_value(&large_header_value); + let duration = start.elapsed(); + + assert!(!result); // 应该拒绝 + assert!(duration < std::time::Duration::from_millis(1)); // 应该非常快 + } + + /// 测试实际代理场景模拟 + #[test] + fn test_real_world_proxy_scenarios() { + // 场景 1:典型的反向代理配置 + let typical_xff = "203.0.113.195, 198.51.100.1, 10.0.1.100"; + assert!(ValidationUtils::validate_x_forwarded_for(typical_xff)); + + let typical_proxy_chain: Vec = typical_xff.split(',').map(|s| s.trim().parse().unwrap()).collect(); + + assert_eq!(typical_proxy_chain.len(), 3); + assert!(ValidationUtils::validate_proxy_chain_length(&typical_proxy_chain, 10)); + assert!(ValidationUtils::validate_proxy_chain_continuity(&typical_proxy_chain)); + + // 场景 2:负载均衡器场景 + let lb_scenario = "2001:db8::1, 203.0.113.195, 198.51.100.1"; + assert!(ValidationUtils::validate_x_forwarded_for(lb_scenario)); + + // 场景 3:可能被攻击的头部 + let attack_headers = vec![ + ("X-Forwarded-For", "127.0.0.1, 8.8.8.8, 192.168.1.1"), + ("X-Real-IP", "8.8.8.8"), + ("X-Forwarded-Host", "evil.com"), + ]; + + let mut headers = HeaderMap::new(); + for (name, value) in attack_headers { + headers.insert(name, value.parse().unwrap()); + } + + // 头部格式本身应该是有效的 + assert!(ValidationUtils::validate_headers(&headers)); + + // 但内容可能需要进一步验证 + let xff_value = headers.get("X-Forwarded-For").unwrap().to_str().unwrap(); + assert!(ValidationUtils::validate_x_forwarded_for(xff_value)); + + // 场景 4:RFC 7239 格式 + let rfc7239_header = "for=192.0.2.60;proto=https;by=203.0.113.43"; + assert!(ValidationUtils::validate_forwarded_header(rfc7239_header)); + } +} diff --git a/crates/trusted-proxies/test/unit/validator_tests.rs b/crates/trusted-proxies/test/unit/validator_tests.rs index 6238cfff..2094d727 100644 --- a/crates/trusted-proxies/test/unit/validator_tests.rs +++ b/crates/trusted-proxies/test/unit/validator_tests.rs @@ -11,3 +11,215 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Proxy validator unit tests + +#[cfg(test)] +mod tests { + use std::net::{IpAddr, SocketAddr}; + use std::str::FromStr; + + use axum::http::HeaderMap; + + use crate::config::{TrustedProxy, TrustedProxyConfig, ValidationMode}; + use crate::proxy::chain::ProxyChainAnalyzer; + use crate::proxy::validator::{ClientInfo, ProxyValidator}; + + fn create_test_config() -> TrustedProxyConfig { + let proxies = vec![ + TrustedProxy::Single("192.168.1.100".parse().unwrap()), + TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()), + TrustedProxy::Cidr("172.16.0.0/12".parse().unwrap()), + ]; + + TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 5, true, vec![]) + } + + #[test] + fn test_client_info_direct() { + let addr = SocketAddr::new(IpAddr::from([192, 168, 1, 1]), 8080); + let client_info = ClientInfo::direct(addr); + + assert_eq!(client_info.real_ip, IpAddr::from([192, 168, 1, 1])); + assert!(client_info.forwarded_host.is_none()); + assert!(client_info.forwarded_proto.is_none()); + assert!(!client_info.is_from_trusted_proxy); + assert!(client_info.proxy_ip.is_none()); + assert_eq!(client_info.proxy_hops, 0); + assert_eq!(client_info.validation_mode, ValidationMode::Lenient); + assert!(client_info.warnings.is_empty()); + } + + #[test] + fn test_parse_x_forwarded_for() { + use crate::proxy::validator::ProxyValidator; + + // 测试有效的 X-Forwarded-For 头部 + let header_value = "203.0.113.195, 198.51.100.1, 10.0.1.100"; + let result = ProxyValidator::parse_x_forwarded_for(header_value); + + assert_eq!(result.len(), 3); + assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap()); + assert_eq!(result[1], IpAddr::from_str("198.51.100.1").unwrap()); + assert_eq!(result[2], IpAddr::from_str("10.0.1.100").unwrap()); + + // 测试带端口的 IP + let header_value_with_ports = "203.0.113.195:8080, 198.51.100.1:443"; + let result = ProxyValidator::parse_x_forwarded_for(header_value_with_ports); + + assert_eq!(result.len(), 2); + assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap()); + assert_eq!(result[1], IpAddr::from_str("198.51.100.1").unwrap()); + + // 测试空值 + let empty_result = ProxyValidator::parse_x_forwarded_for(""); + assert!(empty_result.is_empty()); + + // 测试无效 IP + let invalid_result = ProxyValidator::parse_x_forwarded_for("invalid, 203.0.113.195"); + assert_eq!(invalid_result.len(), 1); // 无效项被跳过 + } + + #[test] + fn test_proxy_chain_analyzer_lenient() { + let config = create_test_config(); + let analyzer = ProxyChainAnalyzer::new(config.clone()); + + // 测试链:客户端 -> 可信代理 1 -> 可信代理 2 + let chain = vec![ + IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 + IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1 + IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 + ]; + + let current_proxy = IpAddr::from_str("192.168.1.100").unwrap(); + let mut headers = HeaderMap::new(); + + let result = analyzer.analyze_chain(&chain, current_proxy, &headers); + assert!(result.is_ok()); + + let analysis = result.unwrap(); + assert_eq!(analysis.client_ip, IpAddr::from_str("203.0.113.195").unwrap()); + assert_eq!(analysis.hops, 3); + assert!(analysis.is_continuous); + assert_eq!(analysis.validation_mode, ValidationMode::HopByHop); + } + + #[test] + fn test_proxy_chain_analyzer_strict() { + let mut config = create_test_config(); + config.validation_mode = ValidationMode::Strict; + + let analyzer = ProxyChainAnalyzer::new(config); + + // 测试链:客户端 -> 可信代理 1 -> 可信代理 2 (全部可信) + let chain = vec![ + IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 + IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1 + IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 + ]; + + let current_proxy = IpAddr::from_str("192.168.1.100").unwrap(); + let mut headers = HeaderMap::new(); + + let result = analyzer.analyze_chain(&chain, current_proxy, &headers); + assert!(result.is_ok()); + + // 测试链包含不可信代理 + let chain_with_untrusted = vec![ + IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 + IpAddr::from_str("8.8.8.8").unwrap(), // 不可信代理 + IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 + ]; + + let result = analyzer.analyze_chain(&chain_with_untrusted, current_proxy, &headers); + assert!(result.is_err()); + } + + #[test] + fn test_proxy_chain_analyzer_hop_by_hop() { + let config = create_test_config(); + let analyzer = ProxyChainAnalyzer::new(config); + + // 测试链:客户端 -> 不可信代理 -> 可信代理 1 -> 可信代理 2 + let chain = vec![ + IpAddr::from_str("203.0.113.195").unwrap(), // 客户端 + IpAddr::from_str("8.8.8.8").unwrap(), // 不可信代理 + IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1 + IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2 + ]; + + let current_proxy = IpAddr::from_str("192.168.1.100").unwrap(); + let mut headers = HeaderMap::new(); + + let result = analyzer.analyze_chain(&chain, current_proxy, &headers); + assert!(result.is_ok()); + + let analysis = result.unwrap(); + // 应该找到客户端 IP (203.0.113.195) + assert_eq!(analysis.client_ip, IpAddr::from_str("203.0.113.195").unwrap()); + // 应该验证 2 跳 (10.0.1.100 和 192.168.1.100) + assert_eq!(analysis.hops, 2); + } + + #[test] + fn test_chain_continuity_check() { + let config = create_test_config(); + let analyzer = ProxyChainAnalyzer::new(config); + + // 测试连续链 + let full_chain = vec![ + IpAddr::from_str("203.0.113.195").unwrap(), + IpAddr::from_str("10.0.1.100").unwrap(), + IpAddr::from_str("192.168.1.100").unwrap(), + ]; + + let trusted_chain = vec![ + IpAddr::from_str("10.0.1.100").unwrap(), + IpAddr::from_str("192.168.1.100").unwrap(), + ]; + + assert!(analyzer.check_chain_continuity(&full_chain, &trusted_chain)); + + // 测试不连续链 + let bad_trusted_chain = vec![IpAddr::from_str("192.168.1.100").unwrap()]; + + assert!(!analyzer.check_chain_continuity(&full_chain, &bad_trusted_chain)); + } + + #[test] + fn test_validate_ip_addresses() { + let config = create_test_config(); + let analyzer = ProxyChainAnalyzer::new(config); + + // 测试有效 IP + let valid_chain = vec![ + IpAddr::from_str("203.0.113.195").unwrap(), + IpAddr::from_str("10.0.1.100").unwrap(), + ]; + + let result = analyzer.validate_ip_addresses(&valid_chain); + assert!(result.is_ok()); + + // 测试未指定地址 + let invalid_chain = vec![IpAddr::from_str("0.0.0.0").unwrap()]; + + let result = analyzer.validate_ip_addresses(&invalid_chain); + assert!(result.is_err()); + + // 测试多播地址 + let multicast_chain = vec![IpAddr::from_str("224.0.0.1").unwrap()]; + + let result = analyzer.validate_ip_addresses(&multicast_chain); + assert!(result.is_err()); + } + + #[test] + fn test_proxy_validator_creation() { + let config = create_test_config(); + let validator = ProxyValidator::new(config, None); + + // 验证器应该成功创建 + assert!(true); // 如果没有 panic,测试通过 + } +}