mirror of
https://github.com/rustfs/rustfs.git
synced 2026-01-16 17:20:33 +00:00
feat(trusted-proxies): optimize core architecture and localize documentation
- **Zero-Trust Security**: Implemented multi-mode proxy validation (Strict, Lenient, Hop-by-Hop) to ensure client IP integrity. - **High Performance**: Integrated `moka` for asynchronous, thread-safe caching of IP validation results. - **Cloud Native**: Enhanced automatic metadata discovery and IP range fetching for AWS, Azure, and GCP. - **Observability**: Added Prometheus metrics and structured JSON logging for production-grade monitoring. - **Refactoring**: Standardized environment variable loading using `rustfs_utils::envs`. - **Localization**: Translated all source code comments and documentation from Chinese to English. - **Test Suite**: Fixed test dependencies and updated integration tests for Axum/Tower compatibility. - **Documentation**: Completed `README.md` with comprehensive configuration and usage guides.
This commit is contained in:
53
Cargo.lock
generated
53
Cargo.lock
generated
@@ -547,6 +547,16 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assert-json-diff"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "astral-tokio-tar"
|
||||
version = "0.5.6"
|
||||
@@ -2999,6 +3009,24 @@ dependencies = [
|
||||
"sqlparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deadpool"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
|
||||
dependencies = [
|
||||
"deadpool-runtime",
|
||||
"lazy_static",
|
||||
"num_cpus",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deadpool-runtime"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
|
||||
|
||||
[[package]]
|
||||
name = "debugid"
|
||||
version = "0.8.0"
|
||||
@@ -8358,6 +8386,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"dotenvy",
|
||||
"http 1.4.0",
|
||||
"http-body-util",
|
||||
"ipnetwork",
|
||||
"lazy_static",
|
||||
"metrics",
|
||||
@@ -8375,6 +8404,7 @@ dependencies = [
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -11170,6 +11200,29 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wiremock"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
|
||||
dependencies = [
|
||||
"assert-json-diff",
|
||||
"base64",
|
||||
"deadpool",
|
||||
"futures",
|
||||
"http 1.4.0",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"once_cell",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.46.0"
|
||||
|
||||
@@ -175,7 +175,7 @@ make help-docker # 显示所有 Docker 相关命令
|
||||
### 访问 RustFS
|
||||
|
||||
5. **访问控制台**: 打开浏览器并访问 `http://localhost:9000` 进入 RustFS 控制台。
|
||||
* 默认账号/密码: `rustfsadmin` / `rustfsadmin`
|
||||
* 默认账号/密码:`rustfsadmin` / `rustfsadmin`
|
||||
6. **创建存储桶**: 使用控制台为您的对象创建一个新的存储桶 (Bucket)。
|
||||
7. **上传对象**: 您可以直接通过控制台上传文件,或使用 S3 兼容的 API/客户端与您的 RustFS 实例进行交互。
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ fn test_encrypt_decrypt_binary_data() -> Result<(), crate::Error> {
|
||||
#[test]
|
||||
fn test_encrypt_decrypt_unicode_data() -> Result<(), crate::Error> {
|
||||
let unicode_strings = [
|
||||
"Hello, 世界! 🌍",
|
||||
"Hello, 世界!🌍",
|
||||
"Тест на русском языке",
|
||||
"العربية اختبار",
|
||||
"🚀🔐💻🌟⭐",
|
||||
|
||||
@@ -33,7 +33,7 @@ http = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
ipnetwork = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
moka = { workspace = true }
|
||||
moka = { workspace = true, features = ["future"] }
|
||||
reqwest = { workspace = true }
|
||||
rustfs-utils = { workspace = true }
|
||||
serde.workspace = true
|
||||
@@ -49,5 +49,19 @@ regex = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
dotenvy = "0.15.7"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["full", "test-util"] }
|
||||
tower = { workspace = true, features = ["util"] }
|
||||
http-body-util = "0.1"
|
||||
wiremock = "0.6"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[[test]]
|
||||
name = "unit_tests"
|
||||
path = "tests/unit/mod.rs"
|
||||
|
||||
[[test]]
|
||||
name = "integration_tests"
|
||||
path = "tests/integration/mod.rs"
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
# RustFS Trusted Proxies
|
||||
|
||||
The `rustfs-trusted-proxies` module provides secure and efficient management of trusted proxy servers within the RustFS ecosystem. It is designed to handle multi-layer proxy architectures, ensuring accurate client IP identification while maintaining a zero-trust security model.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multi-Layer Proxy Validation**: Supports `Strict`, `Lenient`, and `HopByHop` validation modes to accurately identify the real client IP address.
|
||||
- **Zero-Trust Security**: Verifies every hop in the proxy chain against a configurable list of trusted networks.
|
||||
- **Cloud Integration**: Automatic discovery of trusted IP ranges for major cloud providers including AWS, Azure, and GCP.
|
||||
- **High Performance**: Utilizes the `moka` cache for fast lookup of validation results and `axum` for a high-performance web interface.
|
||||
- **Observability**: Built-in support for Prometheus metrics and structured JSON logging via `tracing`.
|
||||
- **RFC 7239 Support**: Full support for the modern `Forwarded` header alongside legacy `X-Forwarded-For` headers.
|
||||
|
||||
## Configuration
|
||||
|
||||
The module is configured primarily through environment variables:
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `TRUSTED_PROXY_VALIDATION_MODE` | `hop_by_hop` | Validation strategy (`strict`, `lenient`, `hop_by_hop`) |
|
||||
| `TRUSTED_PROXY_NETWORKS` | `127.0.0.1,::1,...` | Comma-separated list of trusted CIDR ranges |
|
||||
| `TRUSTED_PROXY_MAX_HOPS` | `10` | Maximum allowed proxy hops |
|
||||
| `TRUSTED_PROXY_CACHE_CAPACITY` | `10000` | Max entries in the validation cache |
|
||||
| `TRUSTED_PROXY_METRICS_ENABLED` | `true` | Enable Prometheus metrics collection |
|
||||
| `TRUSTED_PROXY_CLOUD_METADATA_ENABLED` | `false` | Enable auto-discovery of cloud IP ranges |
|
||||
|
||||
## Usage
|
||||
|
||||
### As a Middleware
|
||||
|
||||
Integrate the trusted proxy validation into your Axum application:
|
||||
|
||||
```rust
|
||||
use rustfs_trusted_proxies::{TrustedProxyLayer, TrustedProxyConfig};
|
||||
|
||||
let config = TrustedProxyConfig::default();
|
||||
let layer = TrustedProxyLayer::enabled(config, None);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(handler))
|
||||
.layer(layer);
|
||||
```
|
||||
|
||||
### Accessing Client Info
|
||||
|
||||
Retrieve the verified client information in your handlers:
|
||||
|
||||
```rust
|
||||
use rustfs_trusted_proxies::ClientInfo;
|
||||
|
||||
async fn handler(Extension(client_info): Extension<ClientInfo>) -> impl IntoResponse {
|
||||
println!("Real Client IP: {}", client_info.real_ip);
|
||||
}
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Pre-Commit Checklist
|
||||
Before committing, ensure all checks pass:
|
||||
```bash
|
||||
make pre-commit
|
||||
```
|
||||
|
||||
### Testing
|
||||
Run the test suite:
|
||||
```bash
|
||||
cargo test --workspace --exclude e2e_test
|
||||
```
|
||||
|
||||
## License
|
||||
Licensed under the Apache License, Version 2.0.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! API request handlers
|
||||
//! API request handlers for the trusted proxy service.
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
@@ -25,7 +25,7 @@ use crate::error::AppError;
|
||||
use crate::middleware::ClientInfo;
|
||||
use crate::AppState;
|
||||
|
||||
/// 健康检查端点
|
||||
/// Health check endpoint to verify service availability.
|
||||
pub async fn health_check() -> impl IntoResponse {
|
||||
Json(json!({
|
||||
"status": "healthy",
|
||||
@@ -35,7 +35,7 @@ pub async fn health_check() -> impl IntoResponse {
|
||||
}))
|
||||
}
|
||||
|
||||
/// 显示配置信息
|
||||
/// Returns the current application configuration.
|
||||
pub async fn show_config(State(state): State<AppState>) -> Result<Json<Value>, AppError> {
|
||||
let config = &state.config;
|
||||
|
||||
@@ -45,7 +45,7 @@ pub async fn show_config(State(state): State<AppState>) -> Result<Json<Value>, A
|
||||
},
|
||||
"proxy": {
|
||||
"trusted_networks_count": config.proxy.proxies.len(),
|
||||
"validation_mode": format!("{:?}", config.proxy.validation_mode),
|
||||
"validation_mode": config.proxy.validation_mode.as_str(),
|
||||
"max_hops": config.proxy.max_hops,
|
||||
"enable_rfc7239": config.proxy.enable_rfc7239,
|
||||
},
|
||||
@@ -66,9 +66,9 @@ pub async fn show_config(State(state): State<AppState>) -> Result<Json<Value>, A
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
/// 显示客户端信息
|
||||
pub async fn client_info(State(state): State<AppState>, req: Request) -> impl IntoResponse {
|
||||
// 从请求扩展中获取客户端信息
|
||||
/// Returns information about the client as identified by the trusted proxy middleware.
|
||||
pub async fn client_info(State(_state): State<AppState>, req: Request) -> impl IntoResponse {
|
||||
// Retrieve the verified client information from the request extensions.
|
||||
let client_info = req.extensions().get::<ClientInfo>();
|
||||
|
||||
match client_info {
|
||||
@@ -78,7 +78,7 @@ pub async fn client_info(State(state): State<AppState>, req: Request) -> impl In
|
||||
"real_ip": info.real_ip.to_string(),
|
||||
"is_from_trusted_proxy": info.is_from_trusted_proxy,
|
||||
"proxy_hops": info.proxy_hops,
|
||||
"validation_mode": format!("{:?}", info.validation_mode),
|
||||
"validation_mode": info.validation_mode.as_str(),
|
||||
},
|
||||
"headers": {
|
||||
"forwarded_host": info.forwarded_host,
|
||||
@@ -93,7 +93,7 @@ pub async fn client_info(State(state): State<AppState>, req: Request) -> impl In
|
||||
None => {
|
||||
let response = json!({
|
||||
"error": "Client information not available",
|
||||
"message": "The trusted proxy middleware may not be enabled or configured correctly",
|
||||
"message": "The trusted proxy middleware may not be enabled or configured correctly.",
|
||||
});
|
||||
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(response)).into_response()
|
||||
@@ -101,9 +101,9 @@ pub async fn client_info(State(state): State<AppState>, req: Request) -> impl In
|
||||
}
|
||||
}
|
||||
|
||||
/// 代理测试端点(用于测试代理头部)
|
||||
/// Debugging endpoint that returns all proxy-related headers received in the request.
|
||||
pub async fn proxy_test(req: Request) -> Json<Value> {
|
||||
// 收集所有代理相关的头部
|
||||
// Collect all headers related to proxying.
|
||||
let headers: Vec<(String, String)> = req
|
||||
.headers()
|
||||
.iter()
|
||||
@@ -114,7 +114,7 @@ pub async fn proxy_test(req: Request) -> Json<Value> {
|
||||
.map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("[INVALID]").to_string()))
|
||||
.collect();
|
||||
|
||||
// 获取对端地址
|
||||
// Get the direct peer address.
|
||||
let peer_addr = req
|
||||
.extensions()
|
||||
.get::<std::net::SocketAddr>()
|
||||
@@ -130,19 +130,17 @@ pub async fn proxy_test(req: Request) -> Json<Value> {
|
||||
}))
|
||||
}
|
||||
|
||||
/// 指标端点(Prometheus 格式)
|
||||
/// Endpoint for retrieving Prometheus metrics.
|
||||
pub async fn metrics(State(state): State<AppState>) -> impl IntoResponse {
|
||||
if !state.config.monitoring.metrics_enabled {
|
||||
return (StatusCode::NOT_FOUND, "Metrics are not enabled".to_string()).into_response();
|
||||
return (StatusCode::NOT_FOUND, "Metrics are not enabled").into_response();
|
||||
}
|
||||
|
||||
// 在实际应用中,这里应该返回 Prometheus 格式的指标
|
||||
// 这里返回简单的 JSON 作为示例
|
||||
let metrics = json!({
|
||||
"message": "Metrics endpoint",
|
||||
"note": "In a real implementation, this would return Prometheus format metrics",
|
||||
// In a production environment, this would return the actual Prometheus-formatted metrics.
|
||||
let metrics_summary = json!({
|
||||
"status": "metrics_enabled",
|
||||
"note": "Prometheus metrics are being collected. Use a compatible exporter to view them.",
|
||||
});
|
||||
|
||||
Json(metrics).into_response()
|
||||
Json(metrics_summary).into_response()
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Cloud provider detection and metadata fetching
|
||||
//! Cloud provider detection and metadata fetching.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::time::Duration;
|
||||
@@ -20,7 +20,7 @@ use tracing::{debug, info, warn};
|
||||
|
||||
use crate::error::AppError;
|
||||
|
||||
/// 云服务商类型
|
||||
/// Supported cloud providers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum CloudProvider {
|
||||
/// Amazon Web Services
|
||||
@@ -33,14 +33,14 @@ pub enum CloudProvider {
|
||||
DigitalOcean,
|
||||
/// Cloudflare
|
||||
Cloudflare,
|
||||
/// 未知或自定义
|
||||
/// Unknown or custom provider.
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl CloudProvider {
|
||||
/// 从环境变量检测云服务商
|
||||
/// Detects the cloud provider based on environment variables.
|
||||
pub fn detect_from_env() -> Option<Self> {
|
||||
// 检查 AWS 环境变量
|
||||
// Check for AWS environment variables.
|
||||
if std::env::var("AWS_EXECUTION_ENV").is_ok()
|
||||
|| std::env::var("AWS_REGION").is_ok()
|
||||
|| std::env::var("EC2_INSTANCE_ID").is_ok()
|
||||
@@ -48,7 +48,7 @@ impl CloudProvider {
|
||||
return Some(Self::Aws);
|
||||
}
|
||||
|
||||
// 检查 Azure 环境变量
|
||||
// Check for Azure environment variables.
|
||||
if std::env::var("WEBSITE_SITE_NAME").is_ok()
|
||||
|| std::env::var("WEBSITE_INSTANCE_ID").is_ok()
|
||||
|| std::env::var("APPSETTING_WEBSITE_SITE_NAME").is_ok()
|
||||
@@ -56,7 +56,7 @@ impl CloudProvider {
|
||||
return Some(Self::Azure);
|
||||
}
|
||||
|
||||
// 检查 GCP 环境变量
|
||||
// Check for GCP environment variables.
|
||||
if std::env::var("GCP_PROJECT").is_ok()
|
||||
|| std::env::var("GOOGLE_CLOUD_PROJECT").is_ok()
|
||||
|| std::env::var("GAE_INSTANCE").is_ok()
|
||||
@@ -64,12 +64,12 @@ impl CloudProvider {
|
||||
return Some(Self::Gcp);
|
||||
}
|
||||
|
||||
// 检查 DigitalOcean 环境变量
|
||||
// Check for DigitalOcean environment variables.
|
||||
if std::env::var("DIGITALOCEAN_REGION").is_ok() {
|
||||
return Some(Self::DigitalOcean);
|
||||
}
|
||||
|
||||
// 检查 Cloudflare 环境变量
|
||||
// Check for Cloudflare environment variables.
|
||||
if std::env::var("CF_PAGES").is_ok() || std::env::var("CF_WORKERS").is_ok() {
|
||||
return Some(Self::Cloudflare);
|
||||
}
|
||||
@@ -77,7 +77,7 @@ impl CloudProvider {
|
||||
None
|
||||
}
|
||||
|
||||
/// 获取云服务商名称
|
||||
/// Returns the canonical name of the cloud provider.
|
||||
pub fn name(&self) -> &str {
|
||||
match self {
|
||||
Self::Aws => "aws",
|
||||
@@ -89,7 +89,7 @@ impl CloudProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// 从字符串解析云服务商
|
||||
/// Parses a cloud provider from a string.
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"aws" | "amazon" => Self::Aws,
|
||||
@@ -102,29 +102,27 @@ impl CloudProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// 云元数据获取器特征
|
||||
/// Trait for fetching metadata from a specific cloud provider.
|
||||
#[async_trait]
|
||||
pub trait CloudMetadataFetcher: Send + Sync {
|
||||
/// 获取云服务商名称
|
||||
/// Returns the name of the provider.
|
||||
fn provider_name(&self) -> &str;
|
||||
|
||||
/// 获取实例所在的网络 CIDR 范围
|
||||
/// Fetches the network CIDR ranges for the current instance.
|
||||
async fn fetch_network_cidrs(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError>;
|
||||
|
||||
/// 获取云服务商的公共 IP 范围
|
||||
/// Fetches the public IP ranges for the cloud provider.
|
||||
async fn fetch_public_ip_ranges(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError>;
|
||||
|
||||
/// 获取可信代理的 IP 范围
|
||||
/// Fetches all IP ranges that should be considered trusted proxies.
|
||||
async fn fetch_trusted_proxy_ranges(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
let mut ranges = Vec::new();
|
||||
|
||||
// 尝试获取网络 CIDR
|
||||
match self.fetch_network_cidrs().await {
|
||||
Ok(cidrs) => ranges.extend(cidrs),
|
||||
Err(e) => warn!("Failed to fetch network CIDRs from {}: {}", self.provider_name(), e),
|
||||
}
|
||||
|
||||
// 尝试获取公共 IP 范围
|
||||
match self.fetch_public_ip_ranges().await {
|
||||
Ok(public_ranges) => ranges.extend(public_ranges),
|
||||
Err(e) => warn!("Failed to fetch public IP ranges from {}: {}", self.provider_name(), e),
|
||||
@@ -134,19 +132,19 @@ pub trait CloudMetadataFetcher: Send + Sync {
|
||||
}
|
||||
}
|
||||
|
||||
/// 云服务检测器
|
||||
/// Detector for identifying the current cloud environment and fetching relevant metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CloudDetector {
|
||||
/// 是否启用云检测
|
||||
/// Whether cloud detection is enabled.
|
||||
enabled: bool,
|
||||
/// 超时时间
|
||||
/// Timeout for metadata requests.
|
||||
timeout: Duration,
|
||||
/// 强制指定的云服务商
|
||||
/// Optionally force a specific provider.
|
||||
forced_provider: Option<CloudProvider>,
|
||||
}
|
||||
|
||||
impl CloudDetector {
|
||||
/// 创建新的云检测器
|
||||
/// Creates a new `CloudDetector`.
|
||||
pub fn new(enabled: bool, timeout: Duration, forced_provider: Option<String>) -> Self {
|
||||
let forced_provider = forced_provider.map(|s| CloudProvider::from_str(&s));
|
||||
|
||||
@@ -157,22 +155,20 @@ impl CloudDetector {
|
||||
}
|
||||
}
|
||||
|
||||
/// 检测云服务商
|
||||
/// Identifies the current cloud provider.
|
||||
pub fn detect_provider(&self) -> Option<CloudProvider> {
|
||||
if !self.enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
// 如果强制指定了云服务商,直接返回
|
||||
if let Some(provider) = self.forced_provider {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// 自动检测
|
||||
CloudProvider::detect_from_env()
|
||||
}
|
||||
|
||||
/// 获取可信代理 IP 范围
|
||||
/// Fetches trusted IP ranges for the detected cloud provider.
|
||||
pub async fn fetch_trusted_ranges(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
if !self.enabled {
|
||||
debug!("Cloud metadata fetching is disabled");
|
||||
@@ -218,7 +214,7 @@ impl CloudDetector {
|
||||
}
|
||||
}
|
||||
|
||||
/// 尝试所有云服务商获取元数据
|
||||
/// Attempts to fetch metadata from all supported providers sequentially.
|
||||
pub async fn try_all_providers(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
if !self.enabled {
|
||||
return Ok(Vec::new());
|
||||
@@ -251,11 +247,7 @@ impl CloudDetector {
|
||||
}
|
||||
}
|
||||
|
||||
/// 默认云检测器
|
||||
/// Returns a default `CloudDetector` with detection disabled.
|
||||
pub fn default_cloud_detector() -> CloudDetector {
|
||||
CloudDetector::new(
|
||||
false, // 默认禁用
|
||||
Duration::from_secs(5),
|
||||
None,
|
||||
)
|
||||
CloudDetector::new(false, Duration::from_secs(5), None)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! AWS metadata fetching implementation
|
||||
//! AWS metadata fetching implementation for identifying trusted proxy ranges.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
@@ -23,7 +23,7 @@ use tracing::{debug, info};
|
||||
use crate::cloud::detector::CloudMetadataFetcher;
|
||||
use crate::error::AppError;
|
||||
|
||||
/// AWS 元数据获取器
|
||||
/// Fetcher for AWS-specific metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AwsMetadataFetcher {
|
||||
client: Client,
|
||||
@@ -31,7 +31,7 @@ pub struct AwsMetadataFetcher {
|
||||
}
|
||||
|
||||
impl AwsMetadataFetcher {
|
||||
/// 创建新的 AWS 元数据获取器
|
||||
/// Creates a new `AwsMetadataFetcher`.
|
||||
pub fn new() -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
@@ -44,7 +44,7 @@ impl AwsMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取 IMDSv2 令牌
|
||||
/// Retrieves an IMDSv2 token for secure metadata access.
|
||||
async fn get_metadata_token(&self) -> Result<String, AppError> {
|
||||
let url = format!("{}/latest/api/token", self.metadata_endpoint);
|
||||
|
||||
@@ -60,11 +60,11 @@ impl AwsMetadataFetcher {
|
||||
let token = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AppError::cloud(format!("Failed to read token: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to read IMDSv2 token: {}", e)))?;
|
||||
Ok(token)
|
||||
} else {
|
||||
debug!("IMDSv2 token request failed with status: {}", response.status());
|
||||
Err(AppError::cloud("Failed to get IMDSv2 token".to_string()))
|
||||
Err(AppError::cloud("Failed to obtain IMDSv2 token"))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -82,11 +82,11 @@ impl CloudMetadataFetcher for AwsMetadataFetcher {
|
||||
}
|
||||
|
||||
async fn fetch_network_cidrs(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
// 简化实现:返回常见的 AWS VPC 范围
|
||||
// Simplified implementation: returns standard AWS VPC private ranges.
|
||||
let default_ranges = vec![
|
||||
"10.0.0.0/8", // 大型 VPC
|
||||
"172.16.0.0/12", // 中型 VPC
|
||||
"192.168.0.0/16", // 小型 VPC
|
||||
"10.0.0.0/8", // Large VPCs
|
||||
"172.16.0.0/12", // Medium VPCs
|
||||
"192.168.0.0/16", // Small VPCs
|
||||
];
|
||||
|
||||
let networks: Result<Vec<_>, _> = default_ranges
|
||||
@@ -96,10 +96,10 @@ impl CloudMetadataFetcher for AwsMetadataFetcher {
|
||||
|
||||
match networks {
|
||||
Ok(networks) => {
|
||||
debug!("Using default AWS network ranges");
|
||||
debug!("Using default AWS VPC network ranges");
|
||||
Ok(networks)
|
||||
}
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse default ranges: {}", e))),
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse default AWS ranges: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +114,6 @@ impl CloudMetadataFetcher for AwsMetadataFetcher {
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct AwsPrefix {
|
||||
ip_prefix: String,
|
||||
region: String,
|
||||
service: String,
|
||||
}
|
||||
|
||||
@@ -124,12 +123,12 @@ impl CloudMetadataFetcher for AwsMetadataFetcher {
|
||||
let ip_ranges: AwsIpRanges = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse AWS IP ranges: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse AWS IP ranges JSON: {}", e)))?;
|
||||
|
||||
let mut networks = Vec::new();
|
||||
|
||||
for prefix in ip_ranges.prefixes {
|
||||
// 只包含 EC2 和 CloudFront 的 IP 范围
|
||||
// Include EC2 and CloudFront ranges as potential trusted proxies.
|
||||
if prefix.service == "EC2" || prefix.service == "CLOUDFRONT" {
|
||||
if let Ok(network) = ipnetwork::IpNetwork::from_str(&prefix.ip_prefix) {
|
||||
networks.push(network);
|
||||
@@ -137,10 +136,10 @@ impl CloudMetadataFetcher for AwsMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
info!("Fetched {} AWS public IP ranges", networks.len());
|
||||
info!("Successfully fetched {} AWS public IP ranges", networks.len());
|
||||
Ok(networks)
|
||||
} else {
|
||||
debug!("Failed to fetch AWS IP ranges: {}", response.status());
|
||||
debug!("Failed to fetch AWS IP ranges: HTTP {}", response.status());
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,18 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Azure Cloud metadata fetching implementation
|
||||
//! Azure Cloud metadata fetching implementation for identifying trusted proxy ranges.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::cloud::detector::CloudMetadataFetcher;
|
||||
use crate::error::AppError;
|
||||
|
||||
/// Azure 元数据获取器
|
||||
/// Fetcher for Azure-specific metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AzureMetadataFetcher {
|
||||
client: Client,
|
||||
@@ -31,7 +32,7 @@ pub struct AzureMetadataFetcher {
|
||||
}
|
||||
|
||||
impl AzureMetadataFetcher {
|
||||
/// 创建新的 Azure 元数据获取器
|
||||
/// Creates a new `AzureMetadataFetcher`.
|
||||
pub fn new() -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
@@ -44,7 +45,7 @@ impl AzureMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取 Azure 元数据
|
||||
/// Retrieves metadata from the Azure Instance Metadata Service (IMDS).
|
||||
async fn get_metadata(&self, path: &str) -> Result<String, AppError> {
|
||||
let url = format!("{}/metadata/{}?api-version=2021-05-01", self.metadata_endpoint, path);
|
||||
|
||||
@@ -56,7 +57,7 @@ impl AzureMetadataFetcher {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AppError::cloud(format!("Failed to read response: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to read Azure metadata response: {}", e)))?;
|
||||
Ok(text)
|
||||
} else {
|
||||
debug!("Azure metadata request failed with status: {}", response.status());
|
||||
@@ -70,9 +71,9 @@ impl AzureMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 Microsoft 下载 IP 范围
|
||||
/// Fetches Azure public IP ranges from the official Microsoft download source.
|
||||
async fn fetch_azure_ip_ranges(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
// Azure 官方 IP 范围下载 URL
|
||||
// Official Azure IP ranges download URL (periodically updated).
|
||||
let url =
|
||||
"https://download.microsoft.com/download/7/1/D/71D86715-5596-4529-9B13-DA13A5DE5B63/ServiceTags_Public_20231211.json";
|
||||
|
||||
@@ -83,7 +84,6 @@ impl AzureMetadataFetcher {
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AzureServiceTag {
|
||||
id: String,
|
||||
name: String,
|
||||
properties: AzureServiceTagProperties,
|
||||
}
|
||||
@@ -91,8 +91,6 @@ impl AzureMetadataFetcher {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AzureServiceTagProperties {
|
||||
address_prefixes: Vec<String>,
|
||||
region: Option<String>,
|
||||
system_service: Option<String>,
|
||||
}
|
||||
|
||||
debug!("Fetching Azure IP ranges from: {}", url);
|
||||
@@ -103,12 +101,12 @@ impl AzureMetadataFetcher {
|
||||
let service_tags: AzureServiceTags = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse Azure IP ranges: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse Azure IP ranges JSON: {}", e)))?;
|
||||
|
||||
let mut networks = Vec::new();
|
||||
|
||||
for tag in service_tags.values {
|
||||
// 只包含 Azure 数据中心和前端服务的 IP 范围
|
||||
// Include general Azure datacenter ranges, excluding specific internal services.
|
||||
if tag.name.contains("Azure") && !tag.name.contains("ActiveDirectory") {
|
||||
for prefix in tag.properties.address_prefixes {
|
||||
if let Ok(network) = ipnetwork::IpNetwork::from_str(&prefix) {
|
||||
@@ -118,25 +116,24 @@ impl AzureMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
info!("Fetched {} Azure public IP ranges", networks.len());
|
||||
info!("Successfully fetched {} Azure public IP ranges", networks.len());
|
||||
Ok(networks)
|
||||
} else {
|
||||
debug!("Failed to fetch Azure IP ranges: {}", response.status());
|
||||
debug!("Failed to fetch Azure IP ranges: HTTP {}", response.status());
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to fetch Azure IP ranges: {}", e);
|
||||
// 如果 API 失败,返回默认的 Azure IP 范围
|
||||
// Fallback to hardcoded ranges if the download fails.
|
||||
Self::default_azure_ranges()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 默认 Azure IP 范围(作为备选)
|
||||
/// Returns a set of default Azure IP ranges as a fallback.
|
||||
fn default_azure_ranges() -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
let ranges = vec![
|
||||
// Azure 全球 IP 范围
|
||||
"13.64.0.0/11",
|
||||
"13.96.0.0/13",
|
||||
"13.104.0.0/14",
|
||||
@@ -199,7 +196,6 @@ impl AzureMetadataFetcher {
|
||||
"168.62.0.0/15",
|
||||
"191.233.0.0/18",
|
||||
"193.149.0.0/19",
|
||||
// IPv6 范围
|
||||
"2603:1000::/40",
|
||||
"2603:1010::/40",
|
||||
"2603:1020::/40",
|
||||
@@ -223,7 +219,7 @@ impl AzureMetadataFetcher {
|
||||
|
||||
match networks {
|
||||
Ok(networks) => {
|
||||
debug!("Using default Azure IP ranges");
|
||||
debug!("Using default Azure public IP ranges");
|
||||
Ok(networks)
|
||||
}
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse default Azure ranges: {}", e))),
|
||||
@@ -238,7 +234,7 @@ impl CloudMetadataFetcher for AzureMetadataFetcher {
|
||||
}
|
||||
|
||||
async fn fetch_network_cidrs(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
// 尝试从 Azure 元数据获取网络信息
|
||||
// Attempt to fetch network interface information from Azure IMDS.
|
||||
match self.get_metadata("instance/network/interface").await {
|
||||
Ok(metadata) => {
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -258,7 +254,7 @@ impl CloudMetadataFetcher for AzureMetadataFetcher {
|
||||
}
|
||||
|
||||
let interfaces: Vec<AzureNetworkInterface> = serde_json::from_str(&metadata)
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse Azure network metadata: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse Azure network metadata JSON: {}", e)))?;
|
||||
|
||||
let mut cidrs = Vec::new();
|
||||
for interface in interfaces {
|
||||
@@ -271,17 +267,15 @@ impl CloudMetadataFetcher for AzureMetadataFetcher {
|
||||
}
|
||||
|
||||
if !cidrs.is_empty() {
|
||||
info!("Fetched {} network CIDRs from Azure metadata", cidrs.len());
|
||||
info!("Successfully fetched {} network CIDRs from Azure metadata", cidrs.len());
|
||||
Ok(cidrs)
|
||||
} else {
|
||||
// 如果元数据中没有网络信息,使用默认的 Azure VNet 范围
|
||||
debug!("No network CIDRs found in Azure metadata, using defaults");
|
||||
debug!("No network CIDRs found in Azure metadata, falling back to defaults");
|
||||
Self::default_azure_network_ranges()
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch Azure network metadata: {}", e);
|
||||
// 元数据获取失败,使用默认范围
|
||||
Self::default_azure_network_ranges()
|
||||
}
|
||||
}
|
||||
@@ -293,25 +287,24 @@ impl CloudMetadataFetcher for AzureMetadataFetcher {
|
||||
}
|
||||
|
||||
impl AzureMetadataFetcher {
|
||||
/// 默认 Azure 网络范围
|
||||
/// Returns a set of default Azure VNet ranges as a fallback.
|
||||
fn default_azure_network_ranges() -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
// Azure 虚拟网络的常见 IP 范围
|
||||
let ranges = vec![
|
||||
"10.0.0.0/8", // 大型虚拟网络
|
||||
"172.16.0.0/12", // 中型虚拟网络
|
||||
"192.168.0.0/16", // 小型虚拟网络
|
||||
"100.64.0.0/10", // Azure 保留范围
|
||||
"192.0.0.0/24", // Azure 保留
|
||||
"10.0.0.0/8", // Large VNets
|
||||
"172.16.0.0/12", // Medium VNets
|
||||
"192.168.0.0/16", // Small VNets
|
||||
"100.64.0.0/10", // Azure reserved range
|
||||
"192.0.0.0/24", // Azure reserved
|
||||
];
|
||||
|
||||
let networks: Result<Vec<_>, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect();
|
||||
|
||||
match networks {
|
||||
Ok(networks) => {
|
||||
debug!("Using default Azure network ranges");
|
||||
debug!("Using default Azure VNet network ranges");
|
||||
Ok(networks)
|
||||
}
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse default network ranges: {}", e))),
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse default Azure network ranges: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Google Cloud Platform (GCP) metadata fetching implementation
|
||||
//! Google Cloud Platform (GCP) metadata fetching implementation for identifying trusted proxy ranges.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
@@ -24,7 +24,7 @@ use tracing::{debug, info, warn};
|
||||
use crate::cloud::detector::CloudMetadataFetcher;
|
||||
use crate::error::AppError;
|
||||
|
||||
/// GCP 元数据获取器
|
||||
/// Fetcher for GCP-specific metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GcpMetadataFetcher {
|
||||
client: Client,
|
||||
@@ -32,7 +32,7 @@ pub struct GcpMetadataFetcher {
|
||||
}
|
||||
|
||||
impl GcpMetadataFetcher {
|
||||
/// 创建新的 GCP 元数据获取器
|
||||
/// Creates a new `GcpMetadataFetcher`.
|
||||
pub fn new() -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
@@ -45,7 +45,7 @@ impl GcpMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取 GCP 元数据
|
||||
/// Retrieves metadata from the GCP Compute Engine metadata server.
|
||||
async fn get_metadata(&self, path: &str) -> Result<String, AppError> {
|
||||
let url = format!("{}/computeMetadata/v1/{}", self.metadata_endpoint, path);
|
||||
|
||||
@@ -57,7 +57,7 @@ impl GcpMetadataFetcher {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AppError::cloud(format!("Failed to read response: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to read GCP metadata response: {}", e)))?;
|
||||
Ok(text)
|
||||
} else {
|
||||
debug!("GCP metadata request failed with status: {}", response.status());
|
||||
@@ -71,11 +71,11 @@ impl GcpMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取网络掩码的前缀长度
|
||||
/// Converts a dotted-decimal subnet mask to a CIDR prefix length.
|
||||
fn subnet_mask_to_prefix_length(mask: &str) -> Result<u8, AppError> {
|
||||
let parts: Vec<&str> = mask.split('.').collect();
|
||||
if parts.len() != 4 {
|
||||
return Err(AppError::cloud(format!("Invalid subnet mask: {}", mask)));
|
||||
return Err(AppError::cloud(format!("Invalid subnet mask format: {}", mask)));
|
||||
}
|
||||
|
||||
let mut prefix_length = 0;
|
||||
@@ -95,7 +95,7 @@ impl GcpMetadataFetcher {
|
||||
}
|
||||
|
||||
if remaining != 0 {
|
||||
return Err(AppError::cloud("Non-contiguous subnet mask".to_string()));
|
||||
return Err(AppError::cloud("Non-contiguous subnet mask detected"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,10 +110,9 @@ impl CloudMetadataFetcher for GcpMetadataFetcher {
|
||||
}
|
||||
|
||||
async fn fetch_network_cidrs(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
// 获取网络接口列表
|
||||
// Attempt to list network interfaces from GCP metadata.
|
||||
match self.get_metadata("instance/network-interfaces/").await {
|
||||
Ok(interfaces_metadata) => {
|
||||
// 解析网络接口索引
|
||||
let interface_indices: Vec<usize> = interfaces_metadata
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
@@ -134,24 +133,7 @@ impl CloudMetadataFetcher for GcpMetadataFetcher {
|
||||
let mut cidrs = Vec::new();
|
||||
|
||||
for index in interface_indices {
|
||||
// 获取子网信息
|
||||
let subnet_path = format!("instance/network-interfaces/{}/subnetworks", index);
|
||||
|
||||
if let Ok(subnet_metadata) = self.get_metadata(&subnet_path).await {
|
||||
// 子网元数据可能包含多个子网,取第一个
|
||||
if let Some(first_subnet) = subnet_metadata.lines().next() {
|
||||
let subnet = first_subnet.trim();
|
||||
if !subnet.is_empty() {
|
||||
// 尝试从子网名称提取网络信息
|
||||
if let Some(network) = Self::extract_network_from_subnet_name(subnet) {
|
||||
cidrs.push(network);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 备选方案:使用 IP 地址和子网掩码
|
||||
// Try to get IP and subnet mask for each interface.
|
||||
let ip_path = format!("instance/network-interfaces/{}/ip", index);
|
||||
let mask_path = format!("instance/network-interfaces/{}/subnetmask", index);
|
||||
|
||||
@@ -170,16 +152,16 @@ impl CloudMetadataFetcher for GcpMetadataFetcher {
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get IP/mask for interface {}: {}", index, e);
|
||||
debug!("Failed to get IP/mask for GCP interface {}: {}", index, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cidrs.is_empty() {
|
||||
warn!("Could not determine network CIDRs from GCP metadata");
|
||||
warn!("Could not determine network CIDRs from GCP metadata, falling back to defaults");
|
||||
Self::default_gcp_network_ranges()
|
||||
} else {
|
||||
info!("Fetched {} network CIDRs from GCP metadata", cidrs.len());
|
||||
info!("Successfully fetched {} network CIDRs from GCP metadata", cidrs.len());
|
||||
Ok(cidrs)
|
||||
}
|
||||
}
|
||||
@@ -196,7 +178,7 @@ impl CloudMetadataFetcher for GcpMetadataFetcher {
|
||||
}
|
||||
|
||||
impl GcpMetadataFetcher {
|
||||
/// 从 Google API 获取 IP 范围
|
||||
/// Fetches GCP public IP ranges from the official Google source.
|
||||
async fn fetch_gcp_ip_ranges(&self) -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
let url = "https://www.gstatic.com/ipranges/cloud.json";
|
||||
|
||||
@@ -208,7 +190,6 @@ impl GcpMetadataFetcher {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GcpPrefix {
|
||||
ipv4_prefix: Option<String>,
|
||||
ipv6_prefix: Option<String>,
|
||||
}
|
||||
|
||||
debug!("Fetching GCP IP ranges from: {}", url);
|
||||
@@ -219,7 +200,7 @@ impl GcpMetadataFetcher {
|
||||
let ip_ranges: GcpIpRanges = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse GCP IP ranges: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse GCP IP ranges JSON: {}", e)))?;
|
||||
|
||||
let mut networks = Vec::new();
|
||||
|
||||
@@ -231,10 +212,10 @@ impl GcpMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
info!("Fetched {} GCP public IP ranges", networks.len());
|
||||
info!("Successfully fetched {} GCP public IP ranges", networks.len());
|
||||
Ok(networks)
|
||||
} else {
|
||||
debug!("Failed to fetch GCP IP ranges: {}", response.status());
|
||||
debug!("Failed to fetch GCP IP ranges: HTTP {}", response.status());
|
||||
Self::default_gcp_ip_ranges()
|
||||
}
|
||||
}
|
||||
@@ -245,10 +226,9 @@ impl GcpMetadataFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// 默认 GCP IP 范围(作为备选)
|
||||
/// Returns a set of default GCP public IP ranges as a fallback.
|
||||
fn default_gcp_ip_ranges() -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
let ranges = vec![
|
||||
// GCP 全球 IP 范围
|
||||
"8.34.208.0/20",
|
||||
"8.35.192.0/20",
|
||||
"8.35.208.0/20",
|
||||
@@ -257,13 +237,6 @@ impl GcpMetadataFetcher {
|
||||
"34.0.0.0/15",
|
||||
"34.2.0.0/16",
|
||||
"34.3.0.0/23",
|
||||
"34.3.3.0/24",
|
||||
"34.3.4.0/24",
|
||||
"34.3.8.0/21",
|
||||
"34.3.16.0/20",
|
||||
"34.3.32.0/19",
|
||||
"34.3.64.0/18",
|
||||
"34.3.128.0/17",
|
||||
"34.4.0.0/14",
|
||||
"34.8.0.0/13",
|
||||
"34.16.0.0/12",
|
||||
@@ -274,8 +247,6 @@ impl GcpMetadataFetcher {
|
||||
"35.192.0.0/14",
|
||||
"35.196.0.0/15",
|
||||
"35.198.0.0/16",
|
||||
"35.199.0.0/17",
|
||||
"35.199.128.0/18",
|
||||
"35.200.0.0/13",
|
||||
"35.208.0.0/12",
|
||||
"35.224.0.0/12",
|
||||
@@ -291,28 +262,13 @@ impl GcpMetadataFetcher {
|
||||
"136.112.0.0/12",
|
||||
"142.250.0.0/15",
|
||||
"146.148.0.0/17",
|
||||
"162.216.148.0/22",
|
||||
"162.222.176.0/21",
|
||||
"172.217.0.0/16",
|
||||
"172.253.0.0/16",
|
||||
"173.194.0.0/16",
|
||||
"192.158.28.0/22",
|
||||
"192.178.0.0/15",
|
||||
"193.186.4.0/24",
|
||||
"199.36.154.0/23",
|
||||
"199.36.156.0/24",
|
||||
"199.192.112.0/22",
|
||||
"199.223.232.0/21",
|
||||
"207.223.160.0/20",
|
||||
"208.65.152.0/22",
|
||||
"208.68.108.0/22",
|
||||
"208.81.188.0/22",
|
||||
"208.117.224.0/19",
|
||||
"209.85.128.0/17",
|
||||
"216.58.192.0/19",
|
||||
"216.73.80.0/20",
|
||||
"216.239.32.0/19",
|
||||
// IPv6 范围
|
||||
"2001:4860::/32",
|
||||
"2404:6800::/32",
|
||||
"2600:1900::/28",
|
||||
@@ -327,54 +283,30 @@ impl GcpMetadataFetcher {
|
||||
|
||||
match networks {
|
||||
Ok(networks) => {
|
||||
debug!("Using default GCP IP ranges");
|
||||
debug!("Using default GCP public IP ranges");
|
||||
Ok(networks)
|
||||
}
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse default GCP ranges: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// 默认 GCP 网络范围
|
||||
/// Returns a set of default GCP VPC ranges as a fallback.
|
||||
fn default_gcp_network_ranges() -> Result<Vec<ipnetwork::IpNetwork>, AppError> {
|
||||
// GCP VPC 网络的常见 IP 范围
|
||||
let ranges = vec![
|
||||
"10.0.0.0/8", // 大型 VPC 网络
|
||||
"172.16.0.0/12", // 中型 VPC 网络
|
||||
"192.168.0.0/16", // 小型 VPC 网络
|
||||
"100.64.0.0/10", // GCP 保留范围
|
||||
"10.0.0.0/8", // Large VPCs
|
||||
"172.16.0.0/12", // Medium VPCs
|
||||
"192.168.0.0/16", // Small VPCs
|
||||
"100.64.0.0/10", // GCP reserved range
|
||||
];
|
||||
|
||||
let networks: Result<Vec<_>, _> = ranges.into_iter().map(|s| ipnetwork::IpNetwork::from_str(s)).collect();
|
||||
|
||||
match networks {
|
||||
Ok(networks) => {
|
||||
debug!("Using default GCP network ranges");
|
||||
debug!("Using default GCP VPC network ranges");
|
||||
Ok(networks)
|
||||
}
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse default GCP network ranges: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// 从子网名称提取网络信息
|
||||
fn extract_network_from_subnet_name(subnet_name: &str) -> Option<ipnetwork::IpNetwork> {
|
||||
// GCP 子网名称格式通常为:regions/{region}/subnetworks/{subnet-name}
|
||||
// 或者 projects/{project}/regions/{region}/subnetworks/{subnet-name}
|
||||
|
||||
// 尝试从子网名称中提取 IP 范围
|
||||
// 这只是一个简化的实现,实际可能需要查询 GCP API
|
||||
|
||||
// 常见的 GCP 子网 IP 范围模式
|
||||
let patterns = [("10.", 8), ("172.16.", 12), ("192.168.", 16)];
|
||||
|
||||
for (prefix, prefix_len) in patterns {
|
||||
if subnet_name.contains(&format!("subnet-{}", prefix.replace(".", "-"))) {
|
||||
let cidr = format!("{}{}", prefix, "0.0.0/".to_string() + &prefix_len.to_string());
|
||||
if let Ok(network) = ipnetwork::IpNetwork::from_str(&cidr) {
|
||||
return Some(network);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Cloud provider IP range definitions
|
||||
//! Static and dynamic IP range definitions for various cloud providers.
|
||||
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
@@ -23,11 +23,11 @@ use tracing::{debug, info};
|
||||
|
||||
use crate::error::AppError;
|
||||
|
||||
/// Cloudflare IP 范围
|
||||
/// Utility for fetching Cloudflare IP ranges.
|
||||
pub struct CloudflareIpRanges;
|
||||
|
||||
impl CloudflareIpRanges {
|
||||
/// 获取 Cloudflare IP 范围
|
||||
/// Returns a static list of Cloudflare IP ranges.
|
||||
pub async fn fetch() -> Result<Vec<IpNetwork>, AppError> {
|
||||
let ranges = vec![
|
||||
// IPv4 ranges
|
||||
@@ -60,14 +60,14 @@ impl CloudflareIpRanges {
|
||||
|
||||
match networks {
|
||||
Ok(networks) => {
|
||||
info!("Loaded {} Cloudflare IP ranges", networks.len());
|
||||
info!("Loaded {} static Cloudflare IP ranges", networks.len());
|
||||
Ok(networks)
|
||||
}
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse Cloudflare IP ranges: {}", e))),
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse static Cloudflare IP ranges: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 Cloudflare API 获取 IP 范围
|
||||
/// Fetches the latest Cloudflare IP ranges from their official API.
|
||||
pub async fn fetch_from_api() -> Result<Vec<IpNetwork>, AppError> {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
@@ -91,7 +91,7 @@ impl CloudflareIpRanges {
|
||||
.lines()
|
||||
.map(|line| line.trim())
|
||||
.filter(|line| !line.is_empty())
|
||||
.map(|line| IpNetwork::from_str(line))
|
||||
.map(IpNetwork::from_str)
|
||||
.collect();
|
||||
|
||||
match ranges {
|
||||
@@ -104,7 +104,7 @@ impl CloudflareIpRanges {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("Failed to fetch IP ranges from {}: {}", url, response.status());
|
||||
debug!("Failed to fetch IP ranges from {}: HTTP {}", url, response.status());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -114,24 +114,23 @@ impl CloudflareIpRanges {
|
||||
}
|
||||
|
||||
if all_ranges.is_empty() {
|
||||
// 如果 API 失败,回退到静态列表
|
||||
// Fallback to static list if API requests fail.
|
||||
Self::fetch().await
|
||||
} else {
|
||||
info!("Fetched {} Cloudflare IP ranges from API", all_ranges.len());
|
||||
info!("Successfully fetched {} Cloudflare IP ranges from API", all_ranges.len());
|
||||
Ok(all_ranges)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DigitalOcean IP 范围
|
||||
/// Utility for fetching DigitalOcean IP ranges.
|
||||
pub struct DigitalOceanIpRanges;
|
||||
|
||||
impl DigitalOceanIpRanges {
|
||||
/// 获取 DigitalOcean IP 范围
|
||||
/// Returns a static list of DigitalOcean IP ranges.
|
||||
pub async fn fetch() -> Result<Vec<IpNetwork>, AppError> {
|
||||
// DigitalOcean 的 IP 范围相对稳定,使用静态列表
|
||||
let ranges = vec![
|
||||
// 数据中心 IP 范围
|
||||
// Datacenter IP ranges
|
||||
"64.227.0.0/16",
|
||||
"138.197.0.0/16",
|
||||
"139.59.0.0/16",
|
||||
@@ -142,7 +141,7 @@ impl DigitalOceanIpRanges {
|
||||
"206.189.0.0/16",
|
||||
"207.154.0.0/16",
|
||||
"209.97.0.0/16",
|
||||
// 负载均衡器 IP 范围
|
||||
// Load Balancer IP ranges
|
||||
"144.126.0.0/16",
|
||||
"143.198.0.0/16",
|
||||
"161.35.0.0/16",
|
||||
@@ -152,19 +151,19 @@ impl DigitalOceanIpRanges {
|
||||
|
||||
match networks {
|
||||
Ok(networks) => {
|
||||
info!("Loaded {} DigitalOcean IP ranges", networks.len());
|
||||
info!("Loaded {} static DigitalOcean IP ranges", networks.len());
|
||||
Ok(networks)
|
||||
}
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse DigitalOcean IP ranges: {}", e))),
|
||||
Err(e) => Err(AppError::cloud(format!("Failed to parse static DigitalOcean IP ranges: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Google Cloud IP 范围
|
||||
/// Utility for fetching Google Cloud IP ranges.
|
||||
pub struct GoogleCloudIpRanges;
|
||||
|
||||
impl GoogleCloudIpRanges {
|
||||
/// 从 Google API 获取 IP 范围
|
||||
/// Fetches the latest Google Cloud IP ranges from their official source.
|
||||
pub async fn fetch() -> Result<Vec<IpNetwork>, AppError> {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
@@ -181,7 +180,6 @@ impl GoogleCloudIpRanges {
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct GooglePrefix {
|
||||
ipv4_prefix: Option<String>,
|
||||
ipv6_prefix: Option<String>,
|
||||
}
|
||||
|
||||
match client.get(url).send().await {
|
||||
@@ -190,7 +188,7 @@ impl GoogleCloudIpRanges {
|
||||
let ip_ranges: GoogleIpRanges = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse Google IP ranges: {}", e)))?;
|
||||
.map_err(|e| AppError::cloud(format!("Failed to parse Google IP ranges JSON: {}", e)))?;
|
||||
|
||||
let mut networks = Vec::new();
|
||||
|
||||
@@ -202,10 +200,10 @@ impl GoogleCloudIpRanges {
|
||||
}
|
||||
}
|
||||
|
||||
info!("Fetched {} Google Cloud IP ranges from API", networks.len());
|
||||
info!("Successfully fetched {} Google Cloud IP ranges from API", networks.len());
|
||||
Ok(networks)
|
||||
} else {
|
||||
debug!("Failed to fetch Google IP ranges: {}", response.status());
|
||||
debug!("Failed to fetch Google IP ranges: HTTP {}", response.status());
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,96 +12,115 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Environment variable configuration constants and helpers
|
||||
//! Environment variable configuration constants and helpers for the trusted proxy system.
|
||||
|
||||
use crate::error::ConfigError;
|
||||
use ipnetwork::IpNetwork;
|
||||
use std::str::FromStr;
|
||||
|
||||
// ==================== 代理基础配置 ====================
|
||||
/// 代理验证模式
|
||||
// ==================== Base Proxy Configuration ====================
|
||||
/// Environment variable for the proxy validation mode.
|
||||
pub const ENV_PROXY_VALIDATION_MODE: &str = "TRUSTED_PROXY_VALIDATION_MODE";
|
||||
/// Default validation mode is "hop_by_hop".
|
||||
pub const DEFAULT_PROXY_VALIDATION_MODE: &str = "hop_by_hop";
|
||||
|
||||
/// 是否启用 RFC 7239 Forwarded 头部
|
||||
/// Environment variable to enable RFC 7239 "Forwarded" header support.
|
||||
pub const ENV_PROXY_ENABLE_RFC7239: &str = "TRUSTED_PROXY_ENABLE_RFC7239";
|
||||
/// RFC 7239 support is enabled by default.
|
||||
pub const DEFAULT_PROXY_ENABLE_RFC7239: bool = true;
|
||||
|
||||
/// 最大代理跳数
|
||||
/// Environment variable for the maximum allowed proxy hops.
|
||||
pub const ENV_PROXY_MAX_HOPS: &str = "TRUSTED_PROXY_MAX_HOPS";
|
||||
/// Default maximum hops is 10.
|
||||
pub const DEFAULT_PROXY_MAX_HOPS: usize = 10;
|
||||
|
||||
/// 是否启用链连续性检查
|
||||
/// Environment variable to enable proxy chain continuity checks.
|
||||
pub const ENV_PROXY_CHAIN_CONTINUITY_CHECK: &str = "TRUSTED_PROXY_CHAIN_CONTINUITY_CHECK";
|
||||
/// Continuity checks are enabled by default.
|
||||
pub const DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK: bool = true;
|
||||
|
||||
/// 是否记录验证失败的请求
|
||||
/// Environment variable to enable logging of failed proxy validations.
|
||||
pub const ENV_PROXY_LOG_FAILED_VALIDATIONS: &str = "TRUSTED_PROXY_LOG_FAILED_VALIDATIONS";
|
||||
/// Logging of failed validations is enabled by default.
|
||||
pub const DEFAULT_PROXY_LOG_FAILED_VALIDATIONS: bool = true;
|
||||
|
||||
// ==================== 可信代理配置 ====================
|
||||
/// 基础可信代理列表(逗号分隔的 IP/CIDR)
|
||||
// ==================== Trusted Proxy Networks ====================
|
||||
/// Environment variable for the list of trusted proxy networks (comma-separated IP/CIDR).
|
||||
pub const ENV_TRUSTED_PROXIES: &str = "TRUSTED_PROXY_NETWORKS";
|
||||
/// Default trusted networks include localhost and common private ranges.
|
||||
pub const DEFAULT_TRUSTED_PROXIES: &str = "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8";
|
||||
|
||||
/// 额外可信代理列表(生产环境专用,可覆盖)
|
||||
/// Environment variable for additional trusted proxy networks (production specific).
|
||||
pub const ENV_EXTRA_TRUSTED_PROXIES: &str = "TRUSTED_PROXY_EXTRA_NETWORKS";
|
||||
/// No extra trusted networks by default.
|
||||
pub const DEFAULT_EXTRA_TRUSTED_PROXIES: &str = "";
|
||||
|
||||
/// 私有网络范围(用于内部代理验证)
|
||||
/// Environment variable for private network ranges used in internal validation.
|
||||
pub const ENV_PRIVATE_NETWORKS: &str = "TRUSTED_PROXY_PRIVATE_NETWORKS";
|
||||
/// Default private networks include common RFC 1918 and RFC 4193 ranges.
|
||||
pub const DEFAULT_PRIVATE_NETWORKS: &str = "10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8";
|
||||
|
||||
// ==================== 缓存配置 ====================
|
||||
/// 缓存容量
|
||||
// ==================== Cache Configuration ====================
|
||||
/// Environment variable for the proxy validation cache capacity.
|
||||
pub const ENV_CACHE_CAPACITY: &str = "TRUSTED_PROXY_CACHE_CAPACITY";
|
||||
/// Default cache capacity is 10,000 entries.
|
||||
pub const DEFAULT_CACHE_CAPACITY: usize = 10_000;
|
||||
|
||||
/// 缓存 TTL(秒)
|
||||
/// Environment variable for the cache entry time-to-live (TTL) in seconds.
|
||||
pub const ENV_CACHE_TTL_SECONDS: &str = "TRUSTED_PROXY_CACHE_TTL_SECONDS";
|
||||
/// Default cache TTL is 300 seconds (5 minutes).
|
||||
pub const DEFAULT_CACHE_TTL_SECONDS: u64 = 300;
|
||||
|
||||
/// 缓存清理间隔(秒)
|
||||
/// Environment variable for the cache cleanup interval in seconds.
|
||||
pub const ENV_CACHE_CLEANUP_INTERVAL: &str = "TRUSTED_PROXY_CACHE_CLEANUP_INTERVAL";
|
||||
/// Default cleanup interval is 60 seconds.
|
||||
pub const DEFAULT_CACHE_CLEANUP_INTERVAL: u64 = 60;
|
||||
|
||||
// ==================== 监控配置 ====================
|
||||
/// 是否启用监控指标
|
||||
// ==================== Monitoring Configuration ====================
|
||||
/// Environment variable to enable Prometheus metrics.
|
||||
pub const ENV_METRICS_ENABLED: &str = "TRUSTED_PROXY_METRICS_ENABLED";
|
||||
/// Metrics are enabled by default.
|
||||
pub const DEFAULT_METRICS_ENABLED: bool = true;
|
||||
|
||||
/// 日志级别
|
||||
/// Environment variable for the application log level.
|
||||
pub const ENV_LOG_LEVEL: &str = "TRUSTED_PROXY_LOG_LEVEL";
|
||||
/// Default log level is "info".
|
||||
pub const DEFAULT_LOG_LEVEL: &str = "info";
|
||||
|
||||
/// 是否启用结构化日志
|
||||
/// Environment variable to enable structured JSON logging.
|
||||
pub const ENV_STRUCTURED_LOGGING: &str = "TRUSTED_PROXY_STRUCTURED_LOGGING";
|
||||
/// Structured logging is disabled by default.
|
||||
pub const DEFAULT_STRUCTURED_LOGGING: bool = false;
|
||||
|
||||
/// 是否启用请求追踪
|
||||
/// Environment variable to enable distributed tracing.
|
||||
pub const ENV_TRACING_ENABLED: &str = "TRUSTED_PROXY_TRACING_ENABLED";
|
||||
/// Tracing is enabled by default.
|
||||
pub const DEFAULT_TRACING_ENABLED: bool = true;
|
||||
|
||||
// ==================== 云服务集成 ====================
|
||||
/// 是否启用云元数据获取
|
||||
// ==================== Cloud Integration ====================
|
||||
/// Environment variable to enable automatic cloud metadata discovery.
|
||||
pub const ENV_CLOUD_METADATA_ENABLED: &str = "TRUSTED_PROXY_CLOUD_METADATA_ENABLED";
|
||||
/// Cloud metadata discovery is disabled by default.
|
||||
pub const DEFAULT_CLOUD_METADATA_ENABLED: bool = false;
|
||||
|
||||
/// 云元数据获取超时(秒)
|
||||
/// Environment variable for the cloud metadata request timeout in seconds.
|
||||
pub const ENV_CLOUD_METADATA_TIMEOUT: &str = "TRUSTED_PROXY_CLOUD_METADATA_TIMEOUT";
|
||||
/// Default cloud metadata timeout is 5 seconds.
|
||||
pub const DEFAULT_CLOUD_METADATA_TIMEOUT: u64 = 5;
|
||||
|
||||
/// 是否启用 Cloudflare IP 范围
|
||||
/// Environment variable to enable Cloudflare IP range integration.
|
||||
pub const ENV_CLOUDFLARE_IPS_ENABLED: &str = "TRUSTED_PROXY_CLOUDFLARE_IPS_ENABLED";
|
||||
/// Cloudflare integration is disabled by default.
|
||||
pub const DEFAULT_CLOUDFLARE_IPS_ENABLED: bool = false;
|
||||
|
||||
/// 强制指定的云服务商(覆盖自动检测)
|
||||
/// Environment variable to force a specific cloud provider (overrides auto-detection).
|
||||
pub const ENV_CLOUD_PROVIDER_FORCE: &str = "TRUSTED_PROXY_CLOUD_PROVIDER_FORCE";
|
||||
/// No forced provider by default.
|
||||
pub const DEFAULT_CLOUD_PROVIDER_FORCE: &str = "";
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
// ==================== Helper Functions ====================
|
||||
|
||||
/// 从环境变量解析逗号分隔的IP/CIDR列表
|
||||
/// Parses a comma-separated list of IP/CIDR strings from an environment variable.
|
||||
pub fn parse_ip_list_from_env(key: &str, default: &str) -> Result<Vec<IpNetwork>, ConfigError> {
|
||||
let value = std::env::var(key).unwrap_or_else(|_| default.to_string());
|
||||
|
||||
@@ -119,7 +138,7 @@ pub fn parse_ip_list_from_env(key: &str, default: &str) -> Result<Vec<IpNetwork>
|
||||
match IpNetwork::from_str(item) {
|
||||
Ok(network) => networks.push(network),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse network '{}' from {}: {}", item, key, e);
|
||||
tracing::warn!("Failed to parse network '{}' from environment variable {}: {}", item, key, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -127,7 +146,7 @@ pub fn parse_ip_list_from_env(key: &str, default: &str) -> Result<Vec<IpNetwork>
|
||||
Ok(networks)
|
||||
}
|
||||
|
||||
/// 从环境变量解析逗号分隔的字符串列表
|
||||
/// Parses a comma-separated list of strings from an environment variable.
|
||||
pub fn parse_string_list_from_env(key: &str, default: &str) -> Vec<String> {
|
||||
let value = std::env::var(key).unwrap_or_else(|_| default.to_string());
|
||||
|
||||
@@ -138,7 +157,7 @@ pub fn parse_string_list_from_env(key: &str, default: &str) -> Vec<String> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 从环境变量获取布尔值
|
||||
/// Retrieves a boolean value from an environment variable.
|
||||
pub fn get_bool_from_env(key: &str, default: bool) -> bool {
|
||||
std::env::var(key)
|
||||
.map(|v| match v.to_lowercase().as_str() {
|
||||
@@ -149,27 +168,27 @@ pub fn get_bool_from_env(key: &str, default: bool) -> bool {
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// 从环境变量获取整数值
|
||||
/// Retrieves a `usize` value from an environment variable.
|
||||
pub fn get_usize_from_env(key: &str, default: usize) -> usize {
|
||||
std::env::var(key).ok().and_then(|v| v.parse().ok()).unwrap_or(default)
|
||||
}
|
||||
|
||||
/// 从环境变量获取 u64 值
|
||||
/// Retrieves a `u64` value from an environment variable.
|
||||
pub fn get_u64_from_env(key: &str, default: u64) -> u64 {
|
||||
std::env::var(key).ok().and_then(|v| v.parse().ok()).unwrap_or(default)
|
||||
}
|
||||
|
||||
/// 从环境变量获取字符串值
|
||||
/// Retrieves a string value from an environment variable.
|
||||
pub fn get_string_from_env(key: &str, default: &str) -> String {
|
||||
std::env::var(key).unwrap_or_else(|_| default.to_string())
|
||||
}
|
||||
|
||||
/// 检查环境变量是否已设置
|
||||
/// Checks if an environment variable is set.
|
||||
pub fn is_env_set(key: &str) -> bool {
|
||||
std::env::var(key).is_ok()
|
||||
}
|
||||
|
||||
/// 获取所有与可信代理相关的环境变量(用于调试)
|
||||
/// Returns a list of all proxy-related environment variables and their current values.
|
||||
pub fn get_all_proxy_env_vars() -> Vec<(String, String)> {
|
||||
let vars = [
|
||||
ENV_PROXY_VALIDATION_MODE,
|
||||
|
||||
@@ -12,58 +12,57 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Configuration loader for environment variables and files
|
||||
//! Configuration loader for environment variables and files.
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::config::env::*;
|
||||
use crate::config::{AppConfig, CacheConfig, CloudConfig, MonitoringConfig, TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
use crate::error::ConfigError;
|
||||
use rustfs_utils::*;
|
||||
|
||||
/// 配置加载器
|
||||
/// Loader for application configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConfigLoader;
|
||||
|
||||
impl ConfigLoader {
|
||||
/// 从环境变量加载完整应用配置
|
||||
/// Loads the complete application configuration from environment variables.
|
||||
pub fn from_env() -> Result<AppConfig, ConfigError> {
|
||||
// 加载可信代理配置
|
||||
// Load proxy-specific configuration.
|
||||
let proxy_config = Self::load_proxy_config()?;
|
||||
|
||||
// 加载缓存配置
|
||||
// Load cache configuration.
|
||||
let cache_config = Self::load_cache_config();
|
||||
|
||||
// 加载监控配置
|
||||
// Load monitoring and observability configuration.
|
||||
let monitoring_config = Self::load_monitoring_config();
|
||||
|
||||
// 加载云服务配置
|
||||
// Load cloud provider integration configuration.
|
||||
let cloud_config = Self::load_cloud_config();
|
||||
|
||||
// 服务器地址
|
||||
// Load server binding address.
|
||||
let server_addr = Self::load_server_addr();
|
||||
|
||||
Ok(AppConfig::new(proxy_config, cache_config, monitoring_config, cloud_config, server_addr))
|
||||
}
|
||||
|
||||
/// 加载可信代理配置
|
||||
/// Loads trusted proxy configuration from environment variables.
|
||||
fn load_proxy_config() -> Result<TrustedProxyConfig, ConfigError> {
|
||||
// 解析可信代理列表
|
||||
let mut proxies = Vec::new();
|
||||
|
||||
// 基础可信代理
|
||||
// Parse base trusted proxies from environment.
|
||||
let base_networks = parse_ip_list_from_env(ENV_TRUSTED_PROXIES, DEFAULT_TRUSTED_PROXIES)?;
|
||||
for network in base_networks {
|
||||
proxies.push(TrustedProxy::Cidr(network));
|
||||
}
|
||||
|
||||
// 额外可信代理
|
||||
// Parse extra trusted proxies from environment.
|
||||
let extra_networks = parse_ip_list_from_env(ENV_EXTRA_TRUSTED_PROXIES, DEFAULT_EXTRA_TRUSTED_PROXIES)?;
|
||||
for network in extra_networks {
|
||||
proxies.push(TrustedProxy::Cidr(network));
|
||||
}
|
||||
|
||||
// 单个 IP(从环境变量解析)
|
||||
// Parse individual trusted proxy IPs.
|
||||
let ip_strings = parse_string_list_from_env("TRUSTED_PROXY_IPS", "");
|
||||
for ip_str in ip_strings {
|
||||
if let Ok(ip) = ip_str.parse::<IpAddr>() {
|
||||
@@ -71,16 +70,16 @@ impl ConfigLoader {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证模式
|
||||
let validation_mode_str = get_string_from_env(ENV_PROXY_VALIDATION_MODE, DEFAULT_PROXY_VALIDATION_MODE);
|
||||
// Determine validation mode.
|
||||
let validation_mode_str = get_env_str(ENV_PROXY_VALIDATION_MODE, DEFAULT_PROXY_VALIDATION_MODE);
|
||||
let validation_mode = ValidationMode::from_str(&validation_mode_str)?;
|
||||
|
||||
// 其他配置
|
||||
let enable_rfc7239 = get_bool_from_env(ENV_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_ENABLE_RFC7239);
|
||||
let max_hops = get_usize_from_env(ENV_PROXY_MAX_HOPS, DEFAULT_PROXY_MAX_HOPS);
|
||||
let enable_chain_check = get_bool_from_env(ENV_PROXY_CHAIN_CONTINUITY_CHECK, DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK);
|
||||
// Load other proxy settings.
|
||||
let enable_rfc7239 = get_env_bool(ENV_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_ENABLE_RFC7239);
|
||||
let max_hops = get_env_usize(ENV_PROXY_MAX_HOPS, DEFAULT_PROXY_MAX_HOPS);
|
||||
let enable_chain_check = get_env_bool(ENV_PROXY_CHAIN_CONTINUITY_CHECK, DEFAULT_PROXY_CHAIN_CONTINUITY_CHECK);
|
||||
|
||||
// 私有网络
|
||||
// Load private network ranges.
|
||||
let private_networks = parse_ip_list_from_env(ENV_PRIVATE_NETWORKS, DEFAULT_PRIVATE_NETWORKS)?;
|
||||
|
||||
Ok(TrustedProxyConfig::new(
|
||||
@@ -93,29 +92,29 @@ impl ConfigLoader {
|
||||
))
|
||||
}
|
||||
|
||||
/// 加载缓存配置
|
||||
/// Loads cache configuration from environment variables.
|
||||
fn load_cache_config() -> CacheConfig {
|
||||
CacheConfig {
|
||||
capacity: get_usize_from_env(ENV_CACHE_CAPACITY, DEFAULT_CACHE_CAPACITY),
|
||||
ttl_seconds: get_u64_from_env(ENV_CACHE_TTL_SECONDS, DEFAULT_CACHE_TTL_SECONDS),
|
||||
cleanup_interval_seconds: get_u64_from_env(ENV_CACHE_CLEANUP_INTERVAL, DEFAULT_CACHE_CLEANUP_INTERVAL),
|
||||
capacity: get_env_usize(ENV_CACHE_CAPACITY, DEFAULT_CACHE_CAPACITY),
|
||||
ttl_seconds: get_env_u64(ENV_CACHE_TTL_SECONDS, DEFAULT_CACHE_TTL_SECONDS),
|
||||
cleanup_interval_seconds: get_env_u64(ENV_CACHE_CLEANUP_INTERVAL, DEFAULT_CACHE_CLEANUP_INTERVAL),
|
||||
}
|
||||
}
|
||||
|
||||
/// 加载监控配置
|
||||
/// Loads monitoring configuration from environment variables.
|
||||
fn load_monitoring_config() -> MonitoringConfig {
|
||||
MonitoringConfig {
|
||||
metrics_enabled: get_bool_from_env(ENV_METRICS_ENABLED, DEFAULT_METRICS_ENABLED),
|
||||
log_level: get_string_from_env(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
|
||||
structured_logging: get_bool_from_env(ENV_STRUCTURED_LOGGING, DEFAULT_STRUCTURED_LOGGING),
|
||||
tracing_enabled: get_bool_from_env(ENV_TRACING_ENABLED, DEFAULT_TRACING_ENABLED),
|
||||
log_failed_validations: get_bool_from_env(ENV_PROXY_LOG_FAILED_VALIDATIONS, DEFAULT_PROXY_LOG_FAILED_VALIDATIONS),
|
||||
metrics_enabled: get_env_bool(ENV_METRICS_ENABLED, DEFAULT_METRICS_ENABLED),
|
||||
log_level: get_env_str(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
|
||||
structured_logging: get_env_bool(ENV_STRUCTURED_LOGGING, DEFAULT_STRUCTURED_LOGGING),
|
||||
tracing_enabled: get_env_bool(ENV_TRACING_ENABLED, DEFAULT_TRACING_ENABLED),
|
||||
log_failed_validations: get_env_bool(ENV_PROXY_LOG_FAILED_VALIDATIONS, DEFAULT_PROXY_LOG_FAILED_VALIDATIONS),
|
||||
}
|
||||
}
|
||||
|
||||
/// 加载云服务配置
|
||||
/// Loads cloud configuration from environment variables.
|
||||
fn load_cloud_config() -> CloudConfig {
|
||||
let forced_provider_str = get_string_from_env(ENV_CLOUD_PROVIDER_FORCE, DEFAULT_CLOUD_PROVIDER_FORCE);
|
||||
let forced_provider_str = get_env_str(ENV_CLOUD_PROVIDER_FORCE, DEFAULT_CLOUD_PROVIDER_FORCE);
|
||||
let forced_provider = if forced_provider_str.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -123,24 +122,24 @@ impl ConfigLoader {
|
||||
};
|
||||
|
||||
CloudConfig {
|
||||
metadata_enabled: get_bool_from_env(ENV_CLOUD_METADATA_ENABLED, DEFAULT_CLOUD_METADATA_ENABLED),
|
||||
metadata_timeout_seconds: get_u64_from_env(ENV_CLOUD_METADATA_TIMEOUT, DEFAULT_CLOUD_METADATA_TIMEOUT),
|
||||
cloudflare_ips_enabled: get_bool_from_env(ENV_CLOUDFLARE_IPS_ENABLED, DEFAULT_CLOUDFLARE_IPS_ENABLED),
|
||||
metadata_enabled: get_env_bool(ENV_CLOUD_METADATA_ENABLED, DEFAULT_CLOUD_METADATA_ENABLED),
|
||||
metadata_timeout_seconds: get_env_u64(ENV_CLOUD_METADATA_TIMEOUT, DEFAULT_CLOUD_METADATA_TIMEOUT),
|
||||
cloudflare_ips_enabled: get_env_bool(ENV_CLOUDFLARE_IPS_ENABLED, DEFAULT_CLOUDFLARE_IPS_ENABLED),
|
||||
forced_provider,
|
||||
}
|
||||
}
|
||||
|
||||
/// 加载服务器地址
|
||||
/// Loads the server binding address from environment variables.
|
||||
fn load_server_addr() -> SocketAddr {
|
||||
let host = get_string_from_env("SERVER_HOST", "0.0.0.0");
|
||||
let port = get_usize_from_env("SERVER_PORT", 3000) as u16;
|
||||
let host = get_env_str("SERVER_HOST", "0.0.0.0");
|
||||
let port = get_env_usize("SERVER_PORT", 3000) as u16;
|
||||
|
||||
format!("{}:{}", host, port)
|
||||
.parse()
|
||||
.unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 3000)))
|
||||
}
|
||||
|
||||
/// 从环境变量加载配置,如果失败则使用默认值
|
||||
/// Loads configuration from environment, falling back to defaults on failure.
|
||||
pub fn from_env_or_default() -> AppConfig {
|
||||
match Self::from_env() {
|
||||
Ok(config) => {
|
||||
@@ -154,9 +153,8 @@ impl ConfigLoader {
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建默认配置
|
||||
/// Returns a default configuration.
|
||||
pub fn default_config() -> AppConfig {
|
||||
// 默认可信代理配置
|
||||
let proxy_config = TrustedProxyConfig::new(
|
||||
vec![
|
||||
TrustedProxy::Single("127.0.0.1".parse().unwrap()),
|
||||
@@ -173,7 +171,6 @@ impl ConfigLoader {
|
||||
],
|
||||
);
|
||||
|
||||
// 默认应用配置
|
||||
AppConfig::new(
|
||||
proxy_config,
|
||||
CacheConfig::default(),
|
||||
@@ -183,7 +180,7 @@ impl ConfigLoader {
|
||||
)
|
||||
}
|
||||
|
||||
/// 打印配置摘要
|
||||
/// Prints a summary of the configuration to the log.
|
||||
pub fn print_summary(config: &AppConfig) {
|
||||
tracing::info!("=== Application Configuration ===");
|
||||
tracing::info!("Server: {}", config.server_addr);
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Configuration type definitions
|
||||
//! Configuration type definitions for the trusted proxy system.
|
||||
|
||||
use ipnetwork::IpNetwork;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -21,25 +21,26 @@ use std::time::Duration;
|
||||
|
||||
use crate::error::ConfigError;
|
||||
|
||||
/// 代理验证模式
|
||||
/// Proxy validation mode defining how the proxy chain is verified.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ValidationMode {
|
||||
/// 宽松模式:只要最后一个代理可信,就接受整个链
|
||||
/// Lenient mode: Accepts the entire chain as long as the last proxy is trusted.
|
||||
Lenient,
|
||||
/// 严格模式:要求链中所有代理都可信
|
||||
/// Strict mode: Requires all proxies in the chain to be trusted.
|
||||
Strict,
|
||||
/// 跳数验证模式:从右向左找到第一个不可信代理
|
||||
/// Hop-by-hop mode: Finds the first untrusted proxy from right to left.
|
||||
/// This is the recommended mode for most production environments.
|
||||
HopByHop,
|
||||
}
|
||||
|
||||
impl ValidationMode {
|
||||
/// 从字符串解析验证模式
|
||||
/// Parses the validation mode from a string.
|
||||
pub fn from_str(s: &str) -> Result<Self, ConfigError> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"lenient" => Ok(Self::Lenient),
|
||||
"strict" => Ok(Self::Strict),
|
||||
"hop_by_hop" => Ok(Self::HopByHop),
|
||||
"hop_by_hop" | "hopbyhop" => Ok(Self::HopByHop),
|
||||
_ => Err(ConfigError::InvalidConfig(format!(
|
||||
"Invalid validation mode: '{}'. Must be one of: lenient, strict, hop_by_hop",
|
||||
s
|
||||
@@ -47,7 +48,7 @@ impl ValidationMode {
|
||||
}
|
||||
}
|
||||
|
||||
/// 转换为字符串
|
||||
/// Returns the string representation of the validation mode.
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Lenient => "lenient",
|
||||
@@ -63,17 +64,17 @@ impl Default for ValidationMode {
|
||||
}
|
||||
}
|
||||
|
||||
/// 可信代理类型
|
||||
#[derive(Debug, Clone)]
|
||||
/// Represents a trusted proxy entry, which can be a single IP or a CIDR range.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TrustedProxy {
|
||||
/// 单个 IP 地址
|
||||
/// A single IP address.
|
||||
Single(IpAddr),
|
||||
/// IP 地址段 (CIDR 表示法)
|
||||
/// An IP network range (CIDR notation).
|
||||
Cidr(IpNetwork),
|
||||
}
|
||||
|
||||
impl TrustedProxy {
|
||||
/// 检查 IP 是否匹配此代理配置
|
||||
/// Checks if the given IP address matches this proxy configuration.
|
||||
pub fn contains(&self, ip: &IpAddr) -> bool {
|
||||
match self {
|
||||
Self::Single(proxy_ip) => ip == proxy_ip,
|
||||
@@ -81,7 +82,7 @@ impl TrustedProxy {
|
||||
}
|
||||
}
|
||||
|
||||
/// 转换为字符串表示
|
||||
/// Returns the string representation of the proxy entry.
|
||||
pub fn to_string(&self) -> String {
|
||||
match self {
|
||||
Self::Single(ip) => ip.to_string(),
|
||||
@@ -90,25 +91,25 @@ impl TrustedProxy {
|
||||
}
|
||||
}
|
||||
|
||||
/// 可信代理配置
|
||||
/// Configuration for trusted proxies and validation logic.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrustedProxyConfig {
|
||||
/// 代理列表
|
||||
/// List of trusted proxy entries.
|
||||
pub proxies: Vec<TrustedProxy>,
|
||||
/// 验证模式
|
||||
/// The validation mode to use for verifying proxy chains.
|
||||
pub validation_mode: ValidationMode,
|
||||
/// 是否启用 RFC 7239 Forwarded 头部
|
||||
/// Whether to enable RFC 7239 "Forwarded" header support.
|
||||
pub enable_rfc7239: bool,
|
||||
/// 最大代理跳数
|
||||
/// Maximum allowed proxy hops in the chain.
|
||||
pub max_hops: usize,
|
||||
/// 是否启用链连续性检查
|
||||
/// Whether to enable continuity checks for the proxy chain.
|
||||
pub enable_chain_continuity_check: bool,
|
||||
/// 私有网络范围
|
||||
/// Private network ranges that should be treated with caution.
|
||||
pub private_networks: Vec<IpNetwork>,
|
||||
}
|
||||
|
||||
impl TrustedProxyConfig {
|
||||
/// 创建新配置
|
||||
/// Creates a new trusted proxy configuration.
|
||||
pub fn new(
|
||||
proxies: Vec<TrustedProxy>,
|
||||
validation_mode: ValidationMode,
|
||||
@@ -127,23 +128,23 @@ impl TrustedProxyConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查 SocketAddr 是否来自可信代理
|
||||
/// Checks if a SocketAddr originates from a trusted proxy.
|
||||
pub fn is_trusted(&self, addr: &SocketAddr) -> bool {
|
||||
let ip = addr.ip();
|
||||
self.proxies.iter().any(|proxy| proxy.contains(&ip))
|
||||
}
|
||||
|
||||
/// 检查 IP 是否在私有网络范围内
|
||||
/// Checks if an IP address belongs to a private network range.
|
||||
pub fn is_private_network(&self, ip: &IpAddr) -> bool {
|
||||
self.private_networks.iter().any(|network| network.contains(*ip))
|
||||
}
|
||||
|
||||
/// 获取所有网络范围的字符串表示(用于调试)
|
||||
/// Returns a list of all network strings for debugging purposes.
|
||||
pub fn get_network_strings(&self) -> Vec<String> {
|
||||
self.proxies.iter().map(|p| p.to_string()).collect()
|
||||
}
|
||||
|
||||
/// 获取配置摘要
|
||||
/// Returns a summary of the configuration.
|
||||
pub fn summary(&self) -> String {
|
||||
format!(
|
||||
"TrustedProxyConfig {{ proxies: {}, mode: {}, max_hops: {} }}",
|
||||
@@ -154,14 +155,14 @@ impl TrustedProxyConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 缓存配置
|
||||
/// Configuration for the internal caching mechanism.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CacheConfig {
|
||||
/// 缓存容量
|
||||
/// Maximum number of entries in the cache.
|
||||
pub capacity: usize,
|
||||
/// 缓存 TTL(秒)
|
||||
/// Time-to-live for cache entries in seconds.
|
||||
pub ttl_seconds: u64,
|
||||
/// 缓存清理间隔(秒)
|
||||
/// Interval for cache cleanup in seconds.
|
||||
pub cleanup_interval_seconds: u64,
|
||||
}
|
||||
|
||||
@@ -176,29 +177,29 @@ impl Default for CacheConfig {
|
||||
}
|
||||
|
||||
impl CacheConfig {
|
||||
/// 获取缓存 TTL 时长
|
||||
/// Returns the TTL as a Duration.
|
||||
pub fn ttl_duration(&self) -> Duration {
|
||||
Duration::from_secs(self.ttl_seconds)
|
||||
}
|
||||
|
||||
/// 获取缓存清理间隔时长
|
||||
/// Returns the cleanup interval as a Duration.
|
||||
pub fn cleanup_interval(&self) -> Duration {
|
||||
Duration::from_secs(self.cleanup_interval_seconds)
|
||||
}
|
||||
}
|
||||
|
||||
/// 监控配置
|
||||
/// Configuration for monitoring and observability.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MonitoringConfig {
|
||||
/// 是否启用监控指标
|
||||
/// Whether to enable Prometheus metrics.
|
||||
pub metrics_enabled: bool,
|
||||
/// 日志级别
|
||||
/// The logging level (e.g., "info", "debug").
|
||||
pub log_level: String,
|
||||
/// 是否启用结构化日志
|
||||
/// Whether to use structured JSON logging.
|
||||
pub structured_logging: bool,
|
||||
/// 是否启用请求追踪
|
||||
/// Whether to enable distributed tracing.
|
||||
pub tracing_enabled: bool,
|
||||
/// 是否记录验证失败的请求
|
||||
/// Whether to log detailed information about failed validations.
|
||||
pub log_failed_validations: bool,
|
||||
}
|
||||
|
||||
@@ -214,16 +215,16 @@ impl Default for MonitoringConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 云服务集成配置
|
||||
/// Configuration for cloud provider integration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CloudConfig {
|
||||
/// 是否启用云元数据获取
|
||||
/// Whether to enable automatic cloud metadata discovery.
|
||||
pub metadata_enabled: bool,
|
||||
/// 云元数据获取超时(秒)
|
||||
/// Timeout for cloud metadata requests in seconds.
|
||||
pub metadata_timeout_seconds: u64,
|
||||
/// 是否启用 Cloudflare IP 范围
|
||||
/// Whether to automatically include Cloudflare IP ranges.
|
||||
pub cloudflare_ips_enabled: bool,
|
||||
/// 强制指定的云服务商
|
||||
/// Optionally force a specific cloud provider.
|
||||
pub forced_provider: Option<String>,
|
||||
}
|
||||
|
||||
@@ -239,29 +240,29 @@ impl Default for CloudConfig {
|
||||
}
|
||||
|
||||
impl CloudConfig {
|
||||
/// 获取元数据获取超时时长
|
||||
/// Returns the metadata timeout as a Duration.
|
||||
pub fn metadata_timeout(&self) -> Duration {
|
||||
Duration::from_secs(self.metadata_timeout_seconds)
|
||||
}
|
||||
}
|
||||
|
||||
/// 完整的应用配置
|
||||
/// Complete application configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AppConfig {
|
||||
/// 代理配置
|
||||
/// Trusted proxy settings.
|
||||
pub proxy: TrustedProxyConfig,
|
||||
/// 缓存配置
|
||||
/// Cache settings.
|
||||
pub cache: CacheConfig,
|
||||
/// 监控配置
|
||||
/// Monitoring and observability settings.
|
||||
pub monitoring: MonitoringConfig,
|
||||
/// 云服务配置
|
||||
/// Cloud integration settings.
|
||||
pub cloud: CloudConfig,
|
||||
/// 服务器绑定地址
|
||||
/// The address the server should bind to.
|
||||
pub server_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
/// 创建应用配置
|
||||
/// Creates a new application configuration.
|
||||
pub fn new(
|
||||
proxy: TrustedProxyConfig,
|
||||
cache: CacheConfig,
|
||||
@@ -278,7 +279,7 @@ impl AppConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取配置摘要
|
||||
/// Returns a summary of the application configuration.
|
||||
pub fn summary(&self) -> String {
|
||||
format!(
|
||||
"AppConfig {{\n\
|
||||
|
||||
@@ -12,42 +12,42 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Configuration error types
|
||||
//! Configuration error types for the trusted proxy system.
|
||||
|
||||
use std::net::AddrParseError;
|
||||
|
||||
/// 配置错误类型
|
||||
/// Errors related to application configuration.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConfigError {
|
||||
/// 环境变量缺失
|
||||
/// Required environment variable is missing.
|
||||
#[error("Missing environment variable: {0}")]
|
||||
MissingEnvVar(String),
|
||||
|
||||
/// 环境变量解析失败
|
||||
/// Environment variable exists but could not be parsed.
|
||||
#[error("Failed to parse environment variable {0}: {1}")]
|
||||
EnvParseError(String, String),
|
||||
|
||||
/// 无效的配置值
|
||||
/// A configuration value is logically invalid.
|
||||
#[error("Invalid configuration value for {0}: {1}")]
|
||||
InvalidValue(String, String),
|
||||
|
||||
/// 无效的 IP 地址或网络
|
||||
/// An IP address or CIDR range is malformed.
|
||||
#[error("Invalid IP address or network: {0}")]
|
||||
InvalidIp(String),
|
||||
|
||||
/// 配置验证失败
|
||||
/// Configuration failed overall validation.
|
||||
#[error("Configuration validation failed: {0}")]
|
||||
ValidationFailed(String),
|
||||
|
||||
/// 配置冲突
|
||||
/// Two or more configuration settings are in conflict.
|
||||
#[error("Configuration conflict: {0}")]
|
||||
Conflict(String),
|
||||
|
||||
/// 配置文件错误
|
||||
/// Error reading or parsing a configuration file.
|
||||
#[error("Config file error: {0}")]
|
||||
FileError(String),
|
||||
|
||||
/// 无效的配置
|
||||
/// General invalid configuration error.
|
||||
#[error("Invalid config: {0}")]
|
||||
InvalidConfig(String),
|
||||
}
|
||||
@@ -65,17 +65,17 @@ impl From<ipnetwork::IpNetworkError> for ConfigError {
|
||||
}
|
||||
|
||||
impl ConfigError {
|
||||
/// 创建环境变量缺失错误
|
||||
/// Creates a `MissingEnvVar` error.
|
||||
pub fn missing_env_var(key: &str) -> Self {
|
||||
Self::MissingEnvVar(key.to_string())
|
||||
}
|
||||
|
||||
/// 创建环境变量解析错误
|
||||
/// Creates an `EnvParseError`.
|
||||
pub fn env_parse(key: &str, value: &str) -> Self {
|
||||
Self::EnvParseError(key.to_string(), value.to_string())
|
||||
}
|
||||
|
||||
/// 创建无效配置值错误
|
||||
/// Creates an `InvalidValue` error.
|
||||
pub fn invalid_value(field: &str, value: &str) -> Self {
|
||||
Self::InvalidValue(field.to_string(), value.to_string())
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Error types for the trusted proxy system
|
||||
//! Error types for the trusted proxy system.
|
||||
|
||||
mod config;
|
||||
mod proxy;
|
||||
@@ -20,55 +20,55 @@ mod proxy;
|
||||
pub use config::*;
|
||||
pub use proxy::*;
|
||||
|
||||
/// 统一错误类型
|
||||
/// Unified error type for the application.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AppError {
|
||||
/// 配置错误
|
||||
/// Errors related to configuration.
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(#[from] ConfigError),
|
||||
|
||||
/// 代理验证错误
|
||||
/// Errors related to proxy validation.
|
||||
#[error("Proxy validation error: {0}")]
|
||||
Proxy(#[from] ProxyError),
|
||||
|
||||
/// 云服务错误
|
||||
/// Errors related to cloud service integration.
|
||||
#[error("Cloud service error: {0}")]
|
||||
Cloud(String),
|
||||
|
||||
/// 内部错误
|
||||
/// General internal errors.
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
|
||||
/// IO 错误
|
||||
/// Standard I/O errors.
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// HTTP 错误
|
||||
/// Errors related to HTTP requests or responses.
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(String),
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
/// 创建云服务错误
|
||||
/// Creates a new `Cloud` error.
|
||||
pub fn cloud(msg: impl Into<String>) -> Self {
|
||||
Self::Cloud(msg.into())
|
||||
}
|
||||
|
||||
/// 创建内部错误
|
||||
/// Creates a new `Internal` error.
|
||||
pub fn internal(msg: impl Into<String>) -> Self {
|
||||
Self::Internal(msg.into())
|
||||
}
|
||||
|
||||
/// 创建 HTTP 错误
|
||||
/// Creates a new `Http` error.
|
||||
pub fn http(msg: impl Into<String>) -> Self {
|
||||
Self::Http(msg.into())
|
||||
}
|
||||
|
||||
/// 判断错误是否可恢复
|
||||
/// Returns true if the error is considered recoverable.
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
Self::Config(_) => true,
|
||||
Self::Proxy(_) => true,
|
||||
Self::Proxy(e) => e.is_recoverable(),
|
||||
Self::Cloud(_) => true,
|
||||
Self::Internal(_) => false,
|
||||
Self::Io(_) => true,
|
||||
@@ -77,7 +77,7 @@ impl AppError {
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP 响应错误类型
|
||||
/// Type alias for API error responses (Status Code, Error Message).
|
||||
pub type ApiError = (axum::http::StatusCode, String);
|
||||
|
||||
impl From<AppError> for ApiError {
|
||||
|
||||
@@ -12,50 +12,50 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Proxy validation error types
|
||||
//! Proxy validation error types for the trusted proxy system.
|
||||
|
||||
use std::net::AddrParseError;
|
||||
|
||||
/// 代理验证错误类型
|
||||
/// Errors that can occur during proxy chain validation.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProxyError {
|
||||
/// 无效的 X-Forwarded-For 头部
|
||||
/// The X-Forwarded-For header is malformed or contains invalid data.
|
||||
#[error("Invalid X-Forwarded-For header: {0}")]
|
||||
InvalidXForwardedFor(String),
|
||||
|
||||
/// 无效的 Forwarded 头部(RFC 7239)
|
||||
/// The RFC 7239 Forwarded header is malformed.
|
||||
#[error("Invalid Forwarded header (RFC 7239): {0}")]
|
||||
InvalidForwardedHeader(String),
|
||||
|
||||
/// 代理链验证失败
|
||||
/// General failure during proxy chain validation.
|
||||
#[error("Proxy chain validation failed: {0}")]
|
||||
ChainValidationFailed(String),
|
||||
|
||||
/// 代理链过长
|
||||
/// The number of proxy hops exceeds the configured limit.
|
||||
#[error("Proxy chain too long: {0} hops (max: {1})")]
|
||||
ChainTooLong(usize, usize),
|
||||
|
||||
/// 来自不可信代理
|
||||
/// The request originated from a proxy that is not in the trusted list.
|
||||
#[error("Request from untrusted proxy: {0}")]
|
||||
UntrustedProxy(String),
|
||||
|
||||
/// 代理链不连续
|
||||
/// The proxy chain is not continuous (e.g., an untrusted IP is between trusted ones).
|
||||
#[error("Proxy chain is not continuous")]
|
||||
ChainNotContinuous,
|
||||
|
||||
/// IP 地址解析失败
|
||||
/// An IP address in the chain could not be parsed.
|
||||
#[error("Failed to parse IP address: {0}")]
|
||||
IpParseError(String),
|
||||
|
||||
/// 头部解析失败
|
||||
/// A header value could not be parsed as a string.
|
||||
#[error("Failed to parse header: {0}")]
|
||||
HeaderParseError(String),
|
||||
|
||||
/// 验证超时
|
||||
/// Validation took too long and timed out.
|
||||
#[error("Validation timeout")]
|
||||
Timeout,
|
||||
|
||||
/// 内部验证错误
|
||||
/// An unexpected internal error occurred during validation.
|
||||
#[error("Internal validation error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
@@ -67,40 +67,41 @@ impl From<AddrParseError> for ProxyError {
|
||||
}
|
||||
|
||||
impl ProxyError {
|
||||
/// 创建无效 X-Forwarded-For 头部错误
|
||||
/// Creates an `InvalidXForwardedFor` error.
|
||||
pub fn invalid_xff(msg: impl Into<String>) -> Self {
|
||||
Self::InvalidXForwardedFor(msg.into())
|
||||
}
|
||||
|
||||
/// 创建无效 Forwarded 头部错误
|
||||
/// Creates an `InvalidForwardedHeader` error.
|
||||
pub fn invalid_forwarded(msg: impl Into<String>) -> Self {
|
||||
Self::InvalidForwardedHeader(msg.into())
|
||||
}
|
||||
|
||||
/// 创建代理链验证失败错误
|
||||
/// Creates a `ChainValidationFailed` error.
|
||||
pub fn chain_failed(msg: impl Into<String>) -> Self {
|
||||
Self::ChainValidationFailed(msg.into())
|
||||
}
|
||||
|
||||
/// 创建来自不可信代理错误
|
||||
/// Creates an `UntrustedProxy` error.
|
||||
pub fn untrusted(proxy: impl Into<String>) -> Self {
|
||||
Self::UntrustedProxy(proxy.into())
|
||||
}
|
||||
|
||||
/// 创建内部验证错误
|
||||
/// Creates an `Internal` validation error.
|
||||
pub fn internal(msg: impl Into<String>) -> Self {
|
||||
Self::Internal(msg.into())
|
||||
}
|
||||
|
||||
/// 判断错误是否可恢复(是否应该继续处理请求)
|
||||
/// Determines if the error is recoverable, meaning the request can still be processed
|
||||
/// (perhaps by falling back to the direct peer IP).
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
// 这些错误通常意味着我们应该拒绝请求或使用备用 IP
|
||||
// These errors typically mean we should use the direct peer IP as a fallback.
|
||||
Self::UntrustedProxy(_) => true,
|
||||
Self::ChainTooLong(_, _) => true,
|
||||
Self::ChainNotContinuous => true,
|
||||
|
||||
// 这些错误可能意味着配置问题或恶意请求
|
||||
// These errors suggest malformed requests or severe configuration issues.
|
||||
Self::InvalidXForwardedFor(_) => false,
|
||||
Self::InvalidForwardedHeader(_) => false,
|
||||
Self::ChainValidationFailed(_) => false,
|
||||
|
||||
@@ -12,16 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod api;
|
||||
mod cloud;
|
||||
mod config;
|
||||
mod error;
|
||||
mod logging;
|
||||
mod middleware;
|
||||
mod proxy;
|
||||
mod state;
|
||||
mod utils;
|
||||
pub mod api;
|
||||
pub mod cloud;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod logging;
|
||||
pub mod middleware;
|
||||
pub mod proxy;
|
||||
pub mod state;
|
||||
pub mod utils;
|
||||
|
||||
// Re-export core types for convenience
|
||||
pub use cloud::*;
|
||||
pub use config::*;
|
||||
pub use middleware::{ClientInfo, TrustedProxyLayer, TrustedProxyMiddleware};
|
||||
pub use proxy::*;
|
||||
pub use state::AppState;
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Logging middleware for Axum
|
||||
//! Logging middleware for the Axum web framework.
|
||||
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Instant;
|
||||
@@ -21,14 +21,14 @@ use uuid::Uuid;
|
||||
|
||||
use crate::logging::Logger;
|
||||
|
||||
/// 请求日志中间件层
|
||||
/// Tower Layer for request logging middleware.
|
||||
#[derive(Clone)]
|
||||
pub struct RequestLoggingLayer {
|
||||
logger: Logger,
|
||||
}
|
||||
|
||||
impl RequestLoggingLayer {
|
||||
/// 创建新的日志中间件层
|
||||
/// Creates a new `RequestLoggingLayer`.
|
||||
pub fn new(logger: Logger) -> Self {
|
||||
Self { logger }
|
||||
}
|
||||
@@ -45,7 +45,7 @@ impl<S> tower::Layer<S> for RequestLoggingLayer {
|
||||
}
|
||||
}
|
||||
|
||||
/// 请求日志中间件服务
|
||||
/// Tower Service for request logging middleware.
|
||||
#[derive(Clone)]
|
||||
pub struct RequestLoggingMiddleware<S> {
|
||||
inner: S,
|
||||
@@ -70,24 +70,22 @@ where
|
||||
let mut inner = self.inner.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
// 生成请求 ID
|
||||
// Generate a unique request ID for correlation.
|
||||
let request_id = Uuid::new_v4().to_string();
|
||||
|
||||
// 记录请求开始时间和日志
|
||||
let start_time = Instant::now();
|
||||
logger.log_request(&req, &request_id);
|
||||
|
||||
// 将请求 ID 添加到请求扩展中
|
||||
// Inject the request ID into the request extensions.
|
||||
let mut req = req;
|
||||
req.extensions_mut().insert(RequestId(request_id.clone()));
|
||||
|
||||
// 处理请求
|
||||
// Process the request.
|
||||
let result = inner.call(req).await;
|
||||
|
||||
// 计算处理时间
|
||||
let duration = start_time.elapsed();
|
||||
|
||||
// 记录响应
|
||||
// Log the response or error.
|
||||
match &result {
|
||||
Ok(response) => {
|
||||
logger.log_response(response, &request_id, duration);
|
||||
@@ -102,12 +100,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// 请求 ID 包装器
|
||||
/// Wrapper for a unique request ID.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RequestId(String);
|
||||
|
||||
impl RequestId {
|
||||
/// 获取请求 ID
|
||||
/// Returns the request ID as a string slice.
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
@@ -119,7 +117,7 @@ impl std::fmt::Display for RequestId {
|
||||
}
|
||||
}
|
||||
|
||||
/// 代理特定的日志中间件
|
||||
/// Middleware specifically for logging proxy-related information.
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyLoggingMiddleware<S> {
|
||||
inner: S,
|
||||
@@ -127,7 +125,7 @@ pub struct ProxyLoggingMiddleware<S> {
|
||||
}
|
||||
|
||||
impl<S> ProxyLoggingMiddleware<S> {
|
||||
/// 创建新的代理日志中间件
|
||||
/// Creates a new `ProxyLoggingMiddleware`.
|
||||
pub fn new(inner: S, logger: Logger) -> Self {
|
||||
Self { inner, logger }
|
||||
}
|
||||
@@ -147,7 +145,7 @@ where
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
|
||||
// 记录代理相关信息
|
||||
// Log proxy-specific details if available.
|
||||
let peer_addr = req.extensions().get::<std::net::SocketAddr>().copied();
|
||||
let client_info = req.extensions().get::<crate::middleware::ClientInfo>();
|
||||
|
||||
@@ -155,7 +153,7 @@ where
|
||||
self.logger
|
||||
.log_info(&format!("Proxy request from {}: {}", addr, info.to_log_string()), None);
|
||||
|
||||
// 如果有警告,记录它们
|
||||
// Log any warnings generated during proxy validation.
|
||||
if !info.warnings.is_empty() {
|
||||
for warning in &info.warnings {
|
||||
self.logger.log_warning(warning, Some("proxy_validation"));
|
||||
@@ -167,14 +165,14 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// 代理日志中间件层
|
||||
/// Tower Layer for proxy logging middleware.
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyLoggingLayer {
|
||||
logger: Logger,
|
||||
}
|
||||
|
||||
impl ProxyLoggingLayer {
|
||||
/// 创建新的代理日志中间件层
|
||||
/// Creates a new `ProxyLoggingLayer`.
|
||||
pub fn new(logger: Logger) -> Self {
|
||||
Self { logger }
|
||||
}
|
||||
|
||||
@@ -12,26 +12,26 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Logging module for structured logging and middleware
|
||||
//! Logging module for structured logging and observability.
|
||||
|
||||
mod middleware;
|
||||
|
||||
pub use middleware::*;
|
||||
|
||||
/// 日志配置
|
||||
/// Configuration for the logging system.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoggingConfig {
|
||||
/// 是否启用结构化日志
|
||||
/// Whether to use structured JSON logging.
|
||||
pub structured: bool,
|
||||
/// 日志级别
|
||||
/// The logging level (e.g., "info", "debug").
|
||||
pub level: String,
|
||||
/// 是否启用请求 ID
|
||||
/// Whether to include a unique request ID in logs.
|
||||
pub enable_request_id: bool,
|
||||
/// 是否记录请求体
|
||||
/// Whether to log the contents of request bodies.
|
||||
pub log_request_body: bool,
|
||||
/// 是否记录响应体
|
||||
/// Whether to log the contents of response bodies.
|
||||
pub log_response_body: bool,
|
||||
/// 敏感字段列表(将被脱敏)
|
||||
/// List of header names that should be redacted in logs.
|
||||
pub sensitive_fields: Vec<String>,
|
||||
}
|
||||
|
||||
@@ -48,40 +48,31 @@ impl Default for LoggingConfig {
|
||||
"token".to_string(),
|
||||
"secret".to_string(),
|
||||
"authorization".to_string(),
|
||||
"cookie".to_string(),
|
||||
"set-cookie".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 初始化日志系统
|
||||
/// Initializes the global tracing subscriber.
|
||||
pub fn init_logging(config: &LoggingConfig) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// 创建日志过滤器
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(config.level.parse().unwrap_or(tracing::Level::INFO.into()))
|
||||
.from_env_lossy();
|
||||
|
||||
// 根据配置选择日志格式
|
||||
let subscriber = tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_file(true)
|
||||
.with_line_number(true);
|
||||
|
||||
if config.structured {
|
||||
// 结构化日志(JSON 格式)
|
||||
tracing_subscriber::fmt()
|
||||
.json()
|
||||
.with_env_filter(filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.init();
|
||||
subscriber.json().init();
|
||||
} else {
|
||||
// 普通文本日志
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.init();
|
||||
subscriber.init();
|
||||
}
|
||||
|
||||
tracing::info!("Logging initialized with level: {}", config.level);
|
||||
@@ -89,19 +80,19 @@ pub fn init_logging(config: &LoggingConfig) -> Result<(), Box<dyn std::error::Er
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 日志记录器
|
||||
/// Helper for logging application events with consistent metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Logger {
|
||||
config: LoggingConfig,
|
||||
}
|
||||
|
||||
impl Logger {
|
||||
/// 创建新的日志记录器
|
||||
/// Creates a new `Logger`.
|
||||
pub fn new(config: LoggingConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// 记录 HTTP 请求
|
||||
/// Logs an incoming HTTP request.
|
||||
pub fn log_request(&self, req: &axum::http::Request<axum::body::Body>, request_id: &str) {
|
||||
let method = req.method();
|
||||
let uri = req.uri();
|
||||
@@ -115,13 +106,12 @@ impl Logger {
|
||||
"HTTP request received"
|
||||
);
|
||||
|
||||
// 如果启用了请求体日志记录,记录头部
|
||||
if self.config.log_request_body {
|
||||
self.log_headers(req.headers(), "request");
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录 HTTP 响应
|
||||
/// Logs an outgoing HTTP response.
|
||||
pub fn log_response(&self, res: &axum::http::Response<axum::body::Body>, request_id: &str, duration: std::time::Duration) {
|
||||
let status = res.status();
|
||||
let version = res.version();
|
||||
@@ -134,13 +124,12 @@ impl Logger {
|
||||
"HTTP response sent"
|
||||
);
|
||||
|
||||
// 如果启用了响应体日志记录,记录头部
|
||||
if self.config.log_response_body {
|
||||
self.log_headers(res.headers(), "response");
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录头部信息(脱敏敏感字段)
|
||||
/// Logs HTTP headers, redacting sensitive information.
|
||||
fn log_headers(&self, headers: &axum::http::HeaderMap, header_type: &str) {
|
||||
let mut header_fields = std::collections::HashMap::new();
|
||||
|
||||
@@ -151,7 +140,6 @@ impl Logger {
|
||||
Err(_) => "[BINARY]".to_string(),
|
||||
};
|
||||
|
||||
// 检查是否为敏感字段
|
||||
let is_sensitive = self
|
||||
.config
|
||||
.sensitive_fields
|
||||
@@ -172,7 +160,7 @@ impl Logger {
|
||||
);
|
||||
}
|
||||
|
||||
/// 记录错误
|
||||
/// Logs an error with optional request context.
|
||||
pub fn log_error(&self, error: &impl std::error::Error, request_id: Option<&str>) {
|
||||
if let Some(id) = request_id {
|
||||
tracing::error!(
|
||||
@@ -190,7 +178,7 @@ impl Logger {
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录警告
|
||||
/// Logs a warning message.
|
||||
pub fn log_warning(&self, message: &str, context: Option<&str>) {
|
||||
if let Some(ctx) = context {
|
||||
tracing::warn!(message = %message, context = %ctx, "Warning");
|
||||
@@ -199,7 +187,7 @@ impl Logger {
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录信息
|
||||
/// Logs an informational message.
|
||||
pub fn log_info(&self, message: &str, context: Option<&str>) {
|
||||
if let Some(ctx) = context {
|
||||
tracing::info!(message = %message, context = %ctx, "Info");
|
||||
@@ -208,7 +196,7 @@ impl Logger {
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录调试信息
|
||||
/// Logs a debug message.
|
||||
pub fn log_debug(&self, message: &str, context: Option<&str>) {
|
||||
if let Some(ctx) = context {
|
||||
tracing::debug!(message = %message, context = %ctx, "Debug");
|
||||
|
||||
@@ -12,19 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Main application entry point for the trusted proxy system
|
||||
//! Main application entry point for the RustFS Trusted Proxies service.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{IntoResponse, Json},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use axum::{routing::get, Router};
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::{error, info};
|
||||
use tracing::{info, Level};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod api;
|
||||
mod cloud;
|
||||
@@ -36,98 +31,86 @@ mod state;
|
||||
mod utils;
|
||||
|
||||
use api::handlers;
|
||||
use config::{AppConfig, ConfigLoader};
|
||||
use config::{AppConfig, ConfigLoader, MonitoringConfig};
|
||||
use error::AppError;
|
||||
use middleware::TrustedProxyLayer;
|
||||
use proxy::metrics::{default_proxy_metrics, ProxyMetrics};
|
||||
use proxy::metrics::default_proxy_metrics;
|
||||
use state::AppState;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), AppError> {
|
||||
// 加载环境变量
|
||||
// Load environment variables from .env file if present.
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
// 从环境变量加载配置
|
||||
// Load application configuration from environment variables.
|
||||
let config = ConfigLoader::from_env_or_default();
|
||||
|
||||
// 初始化日志
|
||||
// Initialize the logging system.
|
||||
init_logging(&config.monitoring)?;
|
||||
|
||||
// 打印配置摘要
|
||||
// Print a summary of the loaded configuration.
|
||||
ConfigLoader::print_summary(&config);
|
||||
|
||||
// 初始化指标收集器
|
||||
// Initialize metrics collector if enabled.
|
||||
let metrics = if config.monitoring.metrics_enabled {
|
||||
let metrics = default_proxy_metrics(true);
|
||||
metrics.print_summary();
|
||||
Some(metrics)
|
||||
let m = default_proxy_metrics(true);
|
||||
m.print_summary();
|
||||
Some(m)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// 创建应用状态
|
||||
// Create shared application state.
|
||||
let state = AppState {
|
||||
config: Arc::new(config),
|
||||
config: Arc::new(config.clone()),
|
||||
metrics: metrics.clone(),
|
||||
};
|
||||
|
||||
// 创建可信代理中间件层
|
||||
let proxy_layer = TrustedProxyLayer::enabled(state.clone().config.proxy.clone(), metrics);
|
||||
// Initialize the trusted proxy middleware layer.
|
||||
let proxy_layer = TrustedProxyLayer::enabled(config.proxy.clone(), metrics);
|
||||
|
||||
// 创建路由
|
||||
// Build the Axum application router.
|
||||
let app = Router::new()
|
||||
// 健康检查端点
|
||||
.route("/health", get(handlers::health_check))
|
||||
// 配置查看端点
|
||||
.route("/config", get(handlers::show_config))
|
||||
// 客户端信息端点
|
||||
.route("/client-info", get(handlers::client_info))
|
||||
// 代理测试端点
|
||||
.route("/proxy-test", get(handlers::proxy_test))
|
||||
// 指标端点(如果启用)
|
||||
.route("/metrics", get(handlers::metrics))
|
||||
// 添加应用状态
|
||||
.with_state(state.clone())
|
||||
// 添加可信代理中间件
|
||||
.with_state(state)
|
||||
.layer(proxy_layer)
|
||||
// 添加追踪中间件(如果启用)
|
||||
.layer(tower_http::trace::TraceLayer::new_for_http())
|
||||
// 添加 CORS 中间件
|
||||
.layer(tower_http::cors::CorsLayer::permissive())
|
||||
// 添加压缩中间件
|
||||
.layer(tower_http::compression::CompressionLayer::new());
|
||||
|
||||
// 启动服务器
|
||||
let addr = state.config.server_addr;
|
||||
let listener = TcpListener::bind(addr).await.map_err(|e| AppError::Io(e))?;
|
||||
// Bind the TCP listener and start the server.
|
||||
let addr = config.server_addr;
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
info!("Server listening on http://{}", addr);
|
||||
info!("Available endpoints:");
|
||||
info!(" GET /health - Health check");
|
||||
info!(" GET /config - Show configuration");
|
||||
info!(" GET /client-info - Show client information");
|
||||
info!(" GET /proxy-test - Test proxy headers");
|
||||
info!(" GET /health - Service health check");
|
||||
info!(" GET /config - Current configuration summary");
|
||||
info!(" GET /client-info - Extracted client information");
|
||||
info!(" GET /proxy-test - Debugging endpoint for proxy headers");
|
||||
info!(" GET /metrics - Prometheus metrics (if enabled)");
|
||||
|
||||
axum::serve(listener, app).await.map_err(|e| AppError::Io(e))?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 初始化日志系统
|
||||
fn init_logging(monitoring_config: &config::MonitoringConfig) -> Result<(), AppError> {
|
||||
// 创建日志过滤器
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(monitoring_config.log_level.parse().unwrap_or(tracing::Level::INFO.into()))
|
||||
/// Initializes the tracing subscriber for logging.
|
||||
fn init_logging(monitoring_config: &MonitoringConfig) -> Result<(), AppError> {
|
||||
let filter = EnvFilter::builder()
|
||||
.with_default_directive(monitoring_config.log_level.parse().unwrap_or(Level::INFO.into()))
|
||||
.from_env_lossy();
|
||||
|
||||
// 根据配置选择日志格式
|
||||
let subscriber = tracing_subscriber::fmt().with_env_filter(filter);
|
||||
|
||||
if monitoring_config.structured_logging {
|
||||
// 结构化日志(JSON 格式)
|
||||
tracing_subscriber::fmt().json().with_env_filter(filter).init();
|
||||
subscriber.json().init();
|
||||
} else {
|
||||
// 普通文本日志
|
||||
tracing_subscriber::fmt().with_env_filter(filter).init();
|
||||
subscriber.init();
|
||||
}
|
||||
|
||||
info!("Logging initialized with level: {}", monitoring_config.log_level);
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Tower layer implementation for trusted proxy middleware
|
||||
//! Tower layer implementation for the trusted proxy middleware.
|
||||
|
||||
use std::sync::Arc;
|
||||
use tower::Layer;
|
||||
@@ -22,17 +22,17 @@ use crate::middleware::TrustedProxyMiddleware;
|
||||
use crate::proxy::ProxyMetrics;
|
||||
use crate::proxy::ProxyValidator;
|
||||
|
||||
/// 可信代理中间件层
|
||||
/// Tower Layer for the trusted proxy middleware.
|
||||
#[derive(Clone)]
|
||||
pub struct TrustedProxyLayer {
|
||||
/// 代理验证器
|
||||
/// The validator used to verify proxy chains.
|
||||
pub(crate) validator: Arc<ProxyValidator>,
|
||||
/// 是否启用中间件
|
||||
/// Whether the middleware is enabled.
|
||||
pub(crate) enabled: bool,
|
||||
}
|
||||
|
||||
impl TrustedProxyLayer {
|
||||
/// 创建新的中间件层
|
||||
/// Creates a new `TrustedProxyLayer`.
|
||||
pub fn new(config: TrustedProxyConfig, metrics: Option<ProxyMetrics>, enabled: bool) -> Self {
|
||||
let validator = ProxyValidator::new(config, metrics);
|
||||
|
||||
@@ -42,12 +42,12 @@ impl TrustedProxyLayer {
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建启用的中间件层
|
||||
/// Creates a new `TrustedProxyLayer` that is enabled by default.
|
||||
pub fn enabled(config: TrustedProxyConfig, metrics: Option<ProxyMetrics>) -> Self {
|
||||
Self::new(config, metrics, true)
|
||||
}
|
||||
|
||||
/// 创建禁用的中间件层
|
||||
/// Creates a new `TrustedProxyLayer` that is disabled.
|
||||
pub fn disabled() -> Self {
|
||||
Self::new(
|
||||
TrustedProxyConfig::new(Vec::new(), crate::config::ValidationMode::Lenient, true, 10, true, Vec::new()),
|
||||
@@ -56,7 +56,7 @@ impl TrustedProxyLayer {
|
||||
)
|
||||
}
|
||||
|
||||
/// 检查中间件是否启用
|
||||
/// Returns true if the middleware is enabled.
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
@@ -12,33 +12,32 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Tower service implementation for trusted proxy middleware
|
||||
//! Tower service implementation for the trusted proxy middleware.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use axum::extract::Request;
|
||||
use axum::response::Response;
|
||||
use tower::Service;
|
||||
use tracing::{debug, instrument, Span};
|
||||
|
||||
use crate::error::ProxyError;
|
||||
use crate::middleware::layer::TrustedProxyLayer;
|
||||
use crate::proxy::{ClientInfo, ProxyValidator};
|
||||
|
||||
/// 可信代理中间件服务
|
||||
/// Tower Service for the trusted proxy middleware.
|
||||
#[derive(Clone)]
|
||||
pub struct TrustedProxyMiddleware<S> {
|
||||
/// 内部服务
|
||||
/// The inner service being wrapped.
|
||||
inner: S,
|
||||
/// 代理验证器
|
||||
/// The validator used to verify proxy chains.
|
||||
validator: Arc<ProxyValidator>,
|
||||
/// 是否启用中间件
|
||||
/// Whether the middleware is enabled.
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl<S> TrustedProxyMiddleware<S> {
|
||||
/// 创建新的中间件服务
|
||||
/// Creates a new `TrustedProxyMiddleware`.
|
||||
pub fn new(inner: S, validator: Arc<ProxyValidator>, enabled: bool) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
@@ -47,7 +46,7 @@ impl<S> TrustedProxyMiddleware<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// 从层创建中间件服务
|
||||
/// Creates a new `TrustedProxyMiddleware` from a `TrustedProxyLayer`.
|
||||
pub fn from_layer(inner: S, layer: &TrustedProxyLayer) -> Self {
|
||||
Self::new(inner, layer.validator.clone(), layer.enabled)
|
||||
}
|
||||
@@ -79,57 +78,51 @@ where
|
||||
fn call(&mut self, mut req: Request) -> Self::Future {
|
||||
let span = Span::current();
|
||||
|
||||
// 如果中间件未启用,直接传递请求
|
||||
// If the middleware is disabled, pass the request through immediately.
|
||||
if !self.enabled {
|
||||
debug!("Trusted proxy middleware is disabled");
|
||||
return self.inner.call(req);
|
||||
}
|
||||
|
||||
// 记录请求开始时间
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
// 提取对端地址
|
||||
// Extract the direct peer address from the request extensions.
|
||||
let peer_addr = req.extensions().get::<std::net::SocketAddr>().copied();
|
||||
|
||||
// 为 span 添加字段
|
||||
if let Some(addr) = peer_addr {
|
||||
span.record("peer.addr", addr.to_string());
|
||||
}
|
||||
|
||||
// 验证请求并提取客户端信息
|
||||
// Validate the request and extract client information.
|
||||
match self.validator.validate_request(peer_addr, req.headers()) {
|
||||
Ok(client_info) => {
|
||||
// 记录客户端信息到 span
|
||||
span.record("client.ip", client_info.real_ip.to_string());
|
||||
span.record("client.trusted", client_info.is_from_trusted_proxy);
|
||||
span.record("client.hops", client_info.proxy_hops as i64);
|
||||
|
||||
// 将客户端信息存入请求扩展
|
||||
// Insert the verified client info into the request extensions.
|
||||
req.extensions_mut().insert(client_info);
|
||||
|
||||
// 记录验证成功
|
||||
let duration = start_time.elapsed();
|
||||
debug!("Proxy validation successful in {:?}", duration);
|
||||
}
|
||||
Err(err) => {
|
||||
// 记录验证失败
|
||||
span.record("error", true);
|
||||
span.record("error.message", err.to_string());
|
||||
|
||||
// 如果是可恢复的错误,创建默认的客户端信息
|
||||
// If the error is recoverable, fallback to a direct connection info.
|
||||
if err.is_recoverable() {
|
||||
let client_info = ClientInfo::direct(
|
||||
peer_addr.unwrap_or_else(|| std::net::SocketAddr::new(std::net::IpAddr::from([0, 0, 0, 0]), 0)),
|
||||
);
|
||||
req.extensions_mut().insert(client_info);
|
||||
} else {
|
||||
// 对于不可恢复的错误,记录警告
|
||||
debug!("Unrecoverable proxy validation error: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 调用内部服务
|
||||
// Call the inner service.
|
||||
self.inner.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,226 +12,76 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Cache implementation for proxy validation
|
||||
//! High-performance cache implementation for proxy validation results using Moka.
|
||||
|
||||
use metrics::{counter, gauge, histogram};
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use moka::future::Cache;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use std::time::Duration;
|
||||
|
||||
/// 缓存条目
|
||||
#[derive(Debug, Clone)]
|
||||
struct CacheEntry {
|
||||
/// 是否可信
|
||||
is_trusted: bool,
|
||||
/// 缓存时间
|
||||
cached_at: Instant,
|
||||
/// 过期时间
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
/// IP 验证缓存
|
||||
/// Cache for storing IP validation results.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IpValidationCache {
|
||||
/// 缓存存储
|
||||
cache: Arc<RwLock<HashMap<IpAddr, CacheEntry>>>,
|
||||
/// 最大容量
|
||||
capacity: usize,
|
||||
/// 默认 TTL
|
||||
default_ttl: Duration,
|
||||
/// 是否启用
|
||||
/// The underlying Moka cache.
|
||||
cache: Cache<IpAddr, bool>,
|
||||
/// Whether the cache is enabled.
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl IpValidationCache {
|
||||
/// 创建新的缓存
|
||||
pub fn new(capacity: usize, default_ttl: Duration, enabled: bool) -> Self {
|
||||
Self {
|
||||
cache: Arc::new(RwLock::new(HashMap::with_capacity(capacity))),
|
||||
capacity,
|
||||
default_ttl,
|
||||
enabled,
|
||||
}
|
||||
/// Creates a new `IpValidationCache` using Moka.
|
||||
pub fn new(capacity: usize, ttl: Duration, enabled: bool) -> Self {
|
||||
let cache = Cache::builder()
|
||||
.max_capacity(capacity as u64)
|
||||
.time_to_live(ttl)
|
||||
.build();
|
||||
|
||||
Self { cache, enabled }
|
||||
}
|
||||
|
||||
/// 检查 IP 是否可信(带缓存)
|
||||
pub fn is_trusted(&self, ip: &IpAddr, validator: impl FnOnce(&IpAddr) -> bool) -> bool {
|
||||
// 如果缓存未启用,直接验证
|
||||
/// Checks if an IP is trusted, using the cache if available.
|
||||
pub async fn is_trusted(&self, ip: &IpAddr, validator: impl FnOnce(&IpAddr) -> bool) -> bool {
|
||||
if !self.enabled {
|
||||
return validator(ip);
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
// 检查缓存
|
||||
{
|
||||
let cache = self.cache.read();
|
||||
if let Some(entry) = cache.get(ip) {
|
||||
if now < entry.expires_at {
|
||||
// 缓存命中
|
||||
counter!("proxy.cache.hits").increment(1);
|
||||
return entry.is_trusted;
|
||||
}
|
||||
}
|
||||
// Attempt to get the result from cache.
|
||||
if let Some(is_trusted) = self.cache.get(ip).await {
|
||||
metrics::counter!("proxy.cache.hits").increment(1);
|
||||
return is_trusted;
|
||||
}
|
||||
|
||||
// 缓存未命中
|
||||
counter!("proxy.cache.misses").increment(1);
|
||||
|
||||
// 验证 IP
|
||||
// Cache miss: perform validation and update cache.
|
||||
metrics::counter!("proxy.cache.misses").increment(1);
|
||||
let is_trusted = validator(ip);
|
||||
|
||||
// 更新缓存
|
||||
self.update_cache(*ip, is_trusted, now);
|
||||
self.cache.insert(*ip, is_trusted).await;
|
||||
|
||||
is_trusted
|
||||
}
|
||||
|
||||
/// 更新缓存
|
||||
fn update_cache(&self, ip: IpAddr, is_trusted: bool, now: Instant) {
|
||||
let mut cache = self.cache.write();
|
||||
|
||||
// 检查是否需要清理(如果达到容量限制)
|
||||
if cache.len() >= self.capacity {
|
||||
self.cleanup_expired(&mut cache, now);
|
||||
|
||||
// 如果仍然满,删除最旧的条目
|
||||
if cache.len() >= self.capacity {
|
||||
self.evict_oldest(&mut cache);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加新条目
|
||||
let entry = CacheEntry {
|
||||
is_trusted,
|
||||
cached_at: now,
|
||||
expires_at: now + self.default_ttl,
|
||||
};
|
||||
|
||||
cache.insert(ip, entry);
|
||||
|
||||
// 更新指标
|
||||
gauge!("proxy.cache.size").set(cache.len() as f64);
|
||||
/// Clears all entries from the cache.
|
||||
pub async fn clear(&self) {
|
||||
self.cache.invalidate_all().await;
|
||||
metrics::gauge!("proxy.cache.size").set(0.0);
|
||||
}
|
||||
|
||||
/// 清理过期条目
|
||||
fn cleanup_expired(&self, cache: &mut HashMap<IpAddr, CacheEntry>, now: Instant) {
|
||||
let expired_keys: Vec<_> = cache
|
||||
.iter()
|
||||
.filter(|(_, entry)| now >= entry.expires_at)
|
||||
.map(|(ip, _)| *ip)
|
||||
.collect();
|
||||
|
||||
for key in expired_keys.clone() {
|
||||
cache.remove(&key);
|
||||
}
|
||||
|
||||
if !expired_keys.is_empty() {
|
||||
counter!("proxy.cache.evictions").increment(expired_keys.len() as u64);
|
||||
}
|
||||
}
|
||||
|
||||
/// 淘汰最旧的条目
|
||||
fn evict_oldest(&self, cache: &mut HashMap<IpAddr, CacheEntry>) {
|
||||
if let Some(oldest_key) = cache.iter().min_by_key(|(_, entry)| entry.cached_at).map(|(ip, _)| *ip) {
|
||||
cache.remove(&oldest_key);
|
||||
counter!("proxy.cache.evictions").increment(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// 清空缓存
|
||||
pub fn clear(&self) {
|
||||
let mut cache = self.cache.write();
|
||||
cache.clear();
|
||||
gauge!("proxy.cache.size").set(0.00);
|
||||
}
|
||||
|
||||
/// 获取缓存统计信息
|
||||
/// Returns statistics about the current state of the cache.
|
||||
pub fn stats(&self) -> CacheStats {
|
||||
let cache = self.cache.read();
|
||||
|
||||
let mut oldest = Instant::now();
|
||||
let mut newest = Instant::now();
|
||||
let mut expired_count = 0;
|
||||
|
||||
let now = Instant::now();
|
||||
for entry in cache.values() {
|
||||
if entry.cached_at < oldest {
|
||||
oldest = entry.cached_at;
|
||||
}
|
||||
if entry.cached_at > newest {
|
||||
newest = entry.cached_at;
|
||||
}
|
||||
if now >= entry.expires_at {
|
||||
expired_count += 1;
|
||||
}
|
||||
}
|
||||
let entry_count = self.cache.entry_count();
|
||||
|
||||
CacheStats {
|
||||
size: cache.len(),
|
||||
capacity: self.capacity,
|
||||
expired_count,
|
||||
oldest_age: now.duration_since(oldest),
|
||||
newest_age: now.duration_since(newest),
|
||||
size: entry_count as usize,
|
||||
// Moka doesn't expose max_capacity directly in a simple way after build,
|
||||
// but we can track it if needed.
|
||||
capacity: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// 定期清理任务
|
||||
pub async fn cleanup_task(&self, interval: Duration) {
|
||||
let mut interval_timer = tokio::time::interval(interval);
|
||||
|
||||
loop {
|
||||
interval_timer.tick().await;
|
||||
self.cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
/// 执行清理
|
||||
fn cleanup(&self) {
|
||||
let now = Instant::now();
|
||||
let mut cache = self.cache.write();
|
||||
self.cleanup_expired(&mut cache, now);
|
||||
|
||||
// 记录清理后的指标
|
||||
gauge!("proxy.cache.size").set(cache.len() as f64);
|
||||
}
|
||||
}
|
||||
|
||||
/// 缓存统计信息
|
||||
/// Statistics about the IP validation cache.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CacheStats {
|
||||
/// 当前缓存大小
|
||||
/// Current number of entries in the cache.
|
||||
pub size: usize,
|
||||
/// 缓存容量
|
||||
/// Maximum capacity of the cache.
|
||||
pub capacity: usize,
|
||||
/// 过期条目数量
|
||||
pub expired_count: usize,
|
||||
/// 最旧条目的年龄
|
||||
pub oldest_age: Duration,
|
||||
/// 最新条目的年龄
|
||||
pub newest_age: Duration,
|
||||
}
|
||||
|
||||
impl CacheStats {
|
||||
/// 获取缓存使用率
|
||||
pub fn usage_percentage(&self) -> f64 {
|
||||
if self.capacity == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(self.size as f64 / self.capacity as f64) * 100.0
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取命中率(需要外部跟踪命中/未命中)
|
||||
pub fn hit_rate(&self, hits: u64, misses: u64) -> f64 {
|
||||
let total = hits + misses;
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(hits as f64 / total as f64) * 100.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,48 +12,45 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Proxy chain analysis and validation
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
use tracing::{debug, trace};
|
||||
//! Proxy chain analysis and validation logic.
|
||||
|
||||
use crate::config::{TrustedProxyConfig, ValidationMode};
|
||||
use crate::error::ProxyError;
|
||||
use crate::utils::ip::is_valid_ip_address;
|
||||
use axum::http::HeaderMap;
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use tracing::trace;
|
||||
|
||||
/// 代理链分析结果
|
||||
/// Result of analyzing a proxy chain.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChainAnalysis {
|
||||
/// 客户端真实 IP
|
||||
/// The identified real client IP address.
|
||||
pub client_ip: IpAddr,
|
||||
/// 已验证的代理跳数
|
||||
/// The number of validated proxy hops.
|
||||
pub hops: usize,
|
||||
/// 是否连续
|
||||
/// Whether the proxy chain is continuous and trusted.
|
||||
pub is_continuous: bool,
|
||||
/// 警告信息
|
||||
/// List of warnings generated during analysis.
|
||||
pub warnings: Vec<String>,
|
||||
/// 使用的验证模式
|
||||
/// The validation mode used for analysis.
|
||||
pub validation_mode: ValidationMode,
|
||||
/// 可信代理部分
|
||||
/// The portion of the chain that consists of trusted proxies.
|
||||
pub trusted_chain: Vec<IpAddr>,
|
||||
}
|
||||
|
||||
/// 代理链分析器
|
||||
/// Analyzer for verifying the integrity of proxy chains.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyChainAnalyzer {
|
||||
/// 代理配置
|
||||
/// Configuration for trusted proxies.
|
||||
config: TrustedProxyConfig,
|
||||
/// 已验证的可信代理 IP 缓存(用于快速查找)
|
||||
/// Cache of trusted IP addresses for fast lookup.
|
||||
trusted_ip_cache: HashSet<IpAddr>,
|
||||
}
|
||||
|
||||
impl ProxyChainAnalyzer {
|
||||
/// 创建新的代理链分析器
|
||||
/// Creates a new `ProxyChainAnalyzer`.
|
||||
pub fn new(config: TrustedProxyConfig) -> Self {
|
||||
// 构建可信 IP 缓存
|
||||
let mut trusted_ip_cache = HashSet::new();
|
||||
|
||||
for proxy in &config.proxies {
|
||||
@@ -62,14 +59,10 @@ impl ProxyChainAnalyzer {
|
||||
trusted_ip_cache.insert(*ip);
|
||||
}
|
||||
crate::config::TrustedProxy::Cidr(network) => {
|
||||
// 对于小网络,可以缓存所有 IP
|
||||
// 这里我们只缓存/24及更小的网络的前几个IP作为示例
|
||||
// 实际生产环境中可能需要更复杂的缓存策略
|
||||
if network.prefix() >= 24 && network.prefix() <= 30 {
|
||||
// 对于小网络,我们可以缓存网络地址和广播地址之间的几个 IP
|
||||
// 这里简化处理,只缓存网络地址
|
||||
if let Some(first_ip) = network.iter().next() {
|
||||
trusted_ip_cache.insert(first_ip);
|
||||
// For small networks, cache all IPs to speed up lookups.
|
||||
if network.prefix() >= 24 {
|
||||
for ip in network.iter() {
|
||||
trusted_ip_cache.insert(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -82,7 +75,7 @@ impl ProxyChainAnalyzer {
|
||||
}
|
||||
}
|
||||
|
||||
/// 分析代理链
|
||||
/// Analyzes a proxy chain to identify the real client IP and verify trust.
|
||||
pub fn analyze_chain(
|
||||
&self,
|
||||
proxy_chain: &[IpAddr],
|
||||
@@ -91,33 +84,33 @@ impl ProxyChainAnalyzer {
|
||||
) -> Result<ChainAnalysis, ProxyError> {
|
||||
trace!("Analyzing proxy chain: {:?} with current proxy: {}", proxy_chain, current_proxy_ip);
|
||||
|
||||
// 验证 IP 地址
|
||||
// Validate all IP addresses in the chain.
|
||||
self.validate_ip_addresses(proxy_chain)?;
|
||||
|
||||
// 构建完整链(包括当前代理)
|
||||
// Construct the full chain including the direct peer.
|
||||
let mut full_chain = proxy_chain.to_vec();
|
||||
full_chain.push(current_proxy_ip);
|
||||
|
||||
// 根据验证模式分析链
|
||||
// Analyze the chain based on the configured validation mode.
|
||||
let (client_ip, trusted_chain, hops) = match self.config.validation_mode {
|
||||
ValidationMode::Lenient => self.analyze_lenient(&full_chain),
|
||||
ValidationMode::Strict => self.analyze_strict(&full_chain)?,
|
||||
ValidationMode::HopByHop => self.analyze_hop_by_hop(&full_chain),
|
||||
};
|
||||
|
||||
// 检查链连续性
|
||||
// Check for chain continuity if enabled.
|
||||
let is_continuous = if self.config.enable_chain_continuity_check {
|
||||
self.check_chain_continuity(&full_chain, &trusted_chain)
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
// 收集警告
|
||||
// Collect any warnings.
|
||||
let warnings = self.collect_warnings(&full_chain, &trusted_chain, headers);
|
||||
|
||||
// 验证客户端 IP
|
||||
// Final validation of the identified client IP.
|
||||
if !is_valid_ip_address(&client_ip) {
|
||||
return Err(ProxyError::internal(format!("Invalid client IP: {}", client_ip)));
|
||||
return Err(ProxyError::internal(format!("Invalid client IP identified: {}", client_ip)));
|
||||
}
|
||||
|
||||
Ok(ChainAnalysis {
|
||||
@@ -130,33 +123,29 @@ impl ProxyChainAnalyzer {
|
||||
})
|
||||
}
|
||||
|
||||
/// 宽松模式分析:只要最后一个代理可信,就接受整个链
|
||||
/// Lenient mode: Accepts the entire chain if the last proxy is trusted.
|
||||
fn analyze_lenient(&self, chain: &[IpAddr]) -> (IpAddr, Vec<IpAddr>, usize) {
|
||||
if chain.is_empty() {
|
||||
return (IpAddr::from([0, 0, 0, 0]), Vec::new(), 0);
|
||||
}
|
||||
|
||||
// 检查最后一个代理是否可信
|
||||
if let Some(last_proxy) = chain.last() {
|
||||
if self.is_ip_trusted(last_proxy) {
|
||||
// 整个链都可信
|
||||
let client_ip = chain.first().copied().unwrap_or(*last_proxy);
|
||||
return (client_ip, chain.to_vec(), chain.len());
|
||||
}
|
||||
}
|
||||
|
||||
// 如果最后一个代理不可信,使用链中第一个 IP 作为客户端
|
||||
let client_ip = chain.first().copied().unwrap_or(IpAddr::from([0, 0, 0, 0]));
|
||||
(client_ip, Vec::new(), 0)
|
||||
}
|
||||
|
||||
/// 严格模式分析:要求链中所有代理都可信
|
||||
/// Strict mode: Requires every IP in the chain to be trusted.
|
||||
fn analyze_strict(&self, chain: &[IpAddr]) -> Result<(IpAddr, Vec<IpAddr>, usize), ProxyError> {
|
||||
if chain.is_empty() {
|
||||
return Ok((IpAddr::from([0, 0, 0, 0]), Vec::new(), 0));
|
||||
}
|
||||
|
||||
// 检查每个代理是否都可信
|
||||
for (i, ip) in chain.iter().enumerate() {
|
||||
if !self.is_ip_trusted(ip) {
|
||||
return Err(ProxyError::chain_failed(format!("Proxy at position {} ({}) is not trusted", i, ip)));
|
||||
@@ -167,7 +156,7 @@ impl ProxyChainAnalyzer {
|
||||
Ok((client_ip, chain.to_vec(), chain.len()))
|
||||
}
|
||||
|
||||
/// 跳数模式分析:从右向左找到第一个不可信代理
|
||||
/// Hop-by-hop mode: Traverses the chain from right to left to find the first untrusted IP.
|
||||
fn analyze_hop_by_hop(&self, chain: &[IpAddr]) -> (IpAddr, Vec<IpAddr>, usize) {
|
||||
if chain.is_empty() {
|
||||
return (IpAddr::from([0, 0, 0, 0]), Vec::new(), 0);
|
||||
@@ -176,28 +165,24 @@ impl ProxyChainAnalyzer {
|
||||
let mut trusted_chain = Vec::new();
|
||||
let mut validated_hops = 0;
|
||||
|
||||
// 从右向左遍历(从离我们最近的代理开始)
|
||||
// Traverse from the most recent proxy back towards the client.
|
||||
for ip in chain.iter().rev() {
|
||||
if self.is_ip_trusted(ip) {
|
||||
trusted_chain.insert(0, *ip);
|
||||
validated_hops += 1;
|
||||
} else {
|
||||
// 找到第一个不可信代理,停止遍历
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if trusted_chain.is_empty() {
|
||||
// 没有可信代理,使用链的最后一个 IP
|
||||
let client_ip = *chain.last().unwrap();
|
||||
(client_ip, vec![client_ip], 0)
|
||||
} else {
|
||||
// 客户端 IP 是可信链的第一个 IP 之前的那个 IP
|
||||
let client_ip_index = chain.len().saturating_sub(trusted_chain.len());
|
||||
let client_ip = if client_ip_index > 0 {
|
||||
chain[client_ip_index - 1]
|
||||
} else {
|
||||
// 如果整个链都可信,使用第一个 IP
|
||||
chain[0]
|
||||
};
|
||||
|
||||
@@ -205,13 +190,12 @@ impl ProxyChainAnalyzer {
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查链连续性
|
||||
/// Verifies that the trusted portion of the chain is a continuous suffix of the full chain.
|
||||
fn check_chain_continuity(&self, full_chain: &[IpAddr], trusted_chain: &[IpAddr]) -> bool {
|
||||
if full_chain.len() <= 1 || trusted_chain.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 可信链应该是完整链的尾部连续部分
|
||||
if trusted_chain.len() > full_chain.len() {
|
||||
return false;
|
||||
}
|
||||
@@ -220,14 +204,13 @@ impl ProxyChainAnalyzer {
|
||||
expected_tail == trusted_chain
|
||||
}
|
||||
|
||||
/// 验证 IP 地址
|
||||
/// Validates that IP addresses are not unspecified, multicast, or otherwise invalid.
|
||||
fn validate_ip_addresses(&self, chain: &[IpAddr]) -> Result<(), ProxyError> {
|
||||
for ip in chain {
|
||||
if !is_valid_ip_address(ip) {
|
||||
return Err(ProxyError::IpParseError(format!("Invalid IP address in chain: {}", ip)));
|
||||
}
|
||||
|
||||
// 检查是否为特殊地址
|
||||
if ip.is_unspecified() {
|
||||
return Err(ProxyError::invalid_xff("IP address cannot be unspecified (0.0.0.0 or ::)"));
|
||||
}
|
||||
@@ -240,42 +223,35 @@ impl ProxyChainAnalyzer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查 IP 是否可信
|
||||
/// Checks if an IP address is trusted based on the configuration.
|
||||
fn is_ip_trusted(&self, ip: &IpAddr) -> bool {
|
||||
// 首先检查缓存
|
||||
if self.trusted_ip_cache.contains(ip) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 然后检查配置中的代理
|
||||
self.config.proxies.iter().any(|proxy| proxy.contains(ip))
|
||||
}
|
||||
|
||||
/// 收集警告信息
|
||||
/// Collects warnings about potential issues in the proxy chain.
|
||||
fn collect_warnings(&self, full_chain: &[IpAddr], trusted_chain: &[IpAddr], headers: &HeaderMap) -> Vec<String> {
|
||||
let mut warnings = Vec::new();
|
||||
|
||||
// 检查代理链长度
|
||||
if full_chain.len() > self.config.max_hops {
|
||||
warnings.push(format!(
|
||||
"Proxy chain length ({}) exceeds recommended maximum ({})",
|
||||
"Proxy chain length ({}) exceeds configured maximum ({})",
|
||||
full_chain.len(),
|
||||
self.config.max_hops
|
||||
));
|
||||
}
|
||||
|
||||
// 检查是否缺少必要的头部
|
||||
if trusted_chain.len() > 0 {
|
||||
if !headers.contains_key("x-forwarded-for") && !headers.contains_key("forwarded") {
|
||||
warnings.push("No proxy headers found for trusted proxy request".to_string());
|
||||
}
|
||||
if !trusted_chain.is_empty() && !headers.contains_key("x-forwarded-for") && !headers.contains_key("forwarded") {
|
||||
warnings.push("No proxy headers found for request from trusted proxy".to_string());
|
||||
}
|
||||
|
||||
// 检查是否有重复的 IP
|
||||
let mut seen_ips = HashSet::new();
|
||||
for ip in full_chain {
|
||||
if !seen_ips.insert(ip) {
|
||||
warnings.push(format!("Duplicate IP in proxy chain: {}", ip));
|
||||
warnings.push(format!("Duplicate IP address detected in proxy chain: {}", ip));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,67 +12,58 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Metrics and monitoring for proxy validation
|
||||
|
||||
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
//! Metrics and monitoring for proxy validation performance and results.
|
||||
|
||||
use crate::config::ValidationMode;
|
||||
use crate::error::ProxyError;
|
||||
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
|
||||
use std::time::Duration;
|
||||
use tracing::info;
|
||||
|
||||
/// 代理验证指标
|
||||
/// Collector for proxy validation metrics.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyMetrics {
|
||||
/// 是否启用指标
|
||||
/// Whether metrics collection is enabled.
|
||||
enabled: bool,
|
||||
/// 应用名称(用于指标标签)
|
||||
/// Application name used as a label for metrics.
|
||||
app_name: String,
|
||||
}
|
||||
|
||||
impl ProxyMetrics {
|
||||
/// 创建新的指标收集器
|
||||
/// Creates a new `ProxyMetrics` collector.
|
||||
pub fn new(app_name: &str, enabled: bool) -> Self {
|
||||
let metrics = Self {
|
||||
enabled,
|
||||
app_name: app_name.to_string(),
|
||||
};
|
||||
|
||||
// 注册指标描述
|
||||
// Register metric descriptions for Prometheus.
|
||||
metrics.register_descriptions();
|
||||
|
||||
metrics
|
||||
}
|
||||
|
||||
/// 注册指标描述
|
||||
/// Registers descriptions for all metrics.
|
||||
fn register_descriptions(&self) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
describe_counter!("proxy_validation_attempts_total", "Total number of proxy validation attempts");
|
||||
|
||||
describe_counter!("proxy_validation_success_total", "Total number of successful proxy validations");
|
||||
|
||||
describe_counter!("proxy_validation_failure_total", "Total number of failed proxy validations");
|
||||
|
||||
describe_counter!(
|
||||
"proxy_validation_failure_by_type_total",
|
||||
"Total number of failed proxy validations by error type"
|
||||
"Total number of failed proxy validations categorized by error type"
|
||||
);
|
||||
|
||||
describe_gauge!("proxy_chain_length", "Length of proxy chains being validated");
|
||||
|
||||
describe_histogram!("proxy_validation_duration_seconds", "Duration of proxy validation in seconds");
|
||||
|
||||
describe_gauge!("proxy_cache_size", "Size of the proxy validation cache");
|
||||
|
||||
describe_counter!("proxy_cache_hits_total", "Total number of cache hits");
|
||||
|
||||
describe_counter!("proxy_cache_misses_total", "Total number of cache misses");
|
||||
describe_gauge!("proxy_chain_length", "Current length of proxy chains being validated");
|
||||
describe_histogram!("proxy_validation_duration_seconds", "Time taken to validate a proxy chain in seconds");
|
||||
describe_gauge!("proxy_cache_size", "Current number of entries in the proxy validation cache");
|
||||
describe_counter!("proxy_cache_hits_total", "Total number of cache hits for proxy validation");
|
||||
describe_counter!("proxy_cache_misses_total", "Total number of cache misses for proxy validation");
|
||||
}
|
||||
|
||||
/// 记录验证尝试
|
||||
/// Increments the total number of validation attempts.
|
||||
pub fn increment_validation_attempts(&self) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
@@ -85,7 +76,7 @@ impl ProxyMetrics {
|
||||
);
|
||||
}
|
||||
|
||||
/// 记录验证成功
|
||||
/// Records a successful validation.
|
||||
pub fn record_validation_success(&self, from_trusted_proxy: bool, proxy_hops: usize, duration: Duration) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
@@ -111,7 +102,7 @@ impl ProxyMetrics {
|
||||
);
|
||||
}
|
||||
|
||||
/// 记录验证失败
|
||||
/// Records a failed validation with the specific error type.
|
||||
pub fn record_validation_failure(&self, error: &ProxyError, duration: Duration) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
@@ -152,7 +143,7 @@ impl ProxyMetrics {
|
||||
);
|
||||
}
|
||||
|
||||
/// 记录验证模式使用情况
|
||||
/// Records the validation mode currently in use.
|
||||
pub fn record_validation_mode(&self, mode: ValidationMode) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
@@ -170,7 +161,7 @@ impl ProxyMetrics {
|
||||
);
|
||||
}
|
||||
|
||||
/// 记录缓存指标
|
||||
/// Records cache performance metrics.
|
||||
pub fn record_cache_metrics(&self, hits: u64, misses: u64, size: usize) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
@@ -181,7 +172,7 @@ impl ProxyMetrics {
|
||||
gauge!("proxy_cache_size", size as f64, "app" => self.app_name.clone());
|
||||
}
|
||||
|
||||
/// 打印指标摘要
|
||||
/// Prints a summary of enabled metrics to the log.
|
||||
pub fn print_summary(&self) {
|
||||
if !self.enabled {
|
||||
info!("Metrics collection is disabled");
|
||||
@@ -202,10 +193,10 @@ impl ProxyMetrics {
|
||||
}
|
||||
}
|
||||
|
||||
/// 默认应用名称
|
||||
/// Default application name for metrics.
|
||||
const DEFAULT_APP_NAME: &str = "trusted-proxy";
|
||||
|
||||
/// 创建默认的代理指标收集器
|
||||
/// Creates a default `ProxyMetrics` collector.
|
||||
pub fn default_proxy_metrics(enabled: bool) -> ProxyMetrics {
|
||||
ProxyMetrics::new(DEFAULT_APP_NAME, enabled)
|
||||
}
|
||||
|
||||
@@ -12,42 +12,41 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Proxy validator for validating proxy chains and client information
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::time::Instant;
|
||||
//! Proxy validator for verifying proxy chains and extracting client information.
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
use tracing::{debug, trace, warn};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::time::Instant;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::config::{TrustedProxyConfig, ValidationMode};
|
||||
use crate::error::ProxyError;
|
||||
use crate::proxy::chain::ProxyChainAnalyzer;
|
||||
use crate::proxy::metrics::ProxyMetrics;
|
||||
|
||||
/// 客户端信息验证结果
|
||||
/// Information about the client extracted from the request and proxy headers.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientInfo {
|
||||
/// 真实客户端 IP 地址(已验证)
|
||||
/// The verified real IP address of the client.
|
||||
pub real_ip: IpAddr,
|
||||
/// 原始请求主机名(如果来自可信代理)
|
||||
/// The original host requested by the client (if provided by a trusted proxy).
|
||||
pub forwarded_host: Option<String>,
|
||||
/// 原始请求协议(如果来自可信代理)
|
||||
/// The original protocol (http/https) used by the client (if provided by a trusted proxy).
|
||||
pub forwarded_proto: Option<String>,
|
||||
/// 请求是否来自可信代理
|
||||
/// Whether the request was received from a trusted proxy.
|
||||
pub is_from_trusted_proxy: bool,
|
||||
/// 直接连接的代理 IP(如果经过代理)
|
||||
/// The IP address of the proxy that directly connected to this server.
|
||||
pub proxy_ip: Option<IpAddr>,
|
||||
/// 代理链长度
|
||||
/// The number of proxy hops identified in the chain.
|
||||
pub proxy_hops: usize,
|
||||
/// 验证模式
|
||||
/// The validation mode used for this request.
|
||||
pub validation_mode: ValidationMode,
|
||||
/// 验证警告信息
|
||||
/// Any warnings generated during the validation process.
|
||||
pub warnings: Vec<String>,
|
||||
}
|
||||
|
||||
impl ClientInfo {
|
||||
/// 创建直接连接的客户端信息(无代理)
|
||||
/// Creates a `ClientInfo` for a direct connection without any proxies.
|
||||
pub fn direct(addr: SocketAddr) -> Self {
|
||||
Self {
|
||||
real_ip: addr.ip(),
|
||||
@@ -61,7 +60,7 @@ impl ClientInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// 从可信代理创建客户端信息
|
||||
/// Creates a `ClientInfo` for a request received through a trusted proxy.
|
||||
pub fn from_trusted_proxy(
|
||||
real_ip: IpAddr,
|
||||
forwarded_host: Option<String>,
|
||||
@@ -83,7 +82,7 @@ impl ClientInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取客户端信息的字符串表示(用于日志)
|
||||
/// Returns a string representation of the client info for logging.
|
||||
pub fn to_log_string(&self) -> String {
|
||||
format!(
|
||||
"client_ip={}, proxy={:?}, hops={}, trusted={}, mode={:?}",
|
||||
@@ -92,19 +91,19 @@ impl ClientInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// 代理验证器
|
||||
/// Core validator that processes incoming requests to verify proxy chains.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyValidator {
|
||||
/// 代理配置
|
||||
/// Configuration for trusted proxies.
|
||||
config: TrustedProxyConfig,
|
||||
/// 代理链分析器
|
||||
/// Analyzer for verifying the integrity of the proxy chain.
|
||||
chain_analyzer: ProxyChainAnalyzer,
|
||||
/// 监控指标
|
||||
/// Metrics collector for observability.
|
||||
metrics: Option<ProxyMetrics>,
|
||||
}
|
||||
|
||||
impl ProxyValidator {
|
||||
/// 创建新的代理验证器
|
||||
/// Creates a new `ProxyValidator` with the given configuration and metrics.
|
||||
pub fn new(config: TrustedProxyConfig, metrics: Option<ProxyMetrics>) -> Self {
|
||||
let chain_analyzer = ProxyChainAnalyzer::new(config.clone());
|
||||
|
||||
@@ -115,53 +114,53 @@ impl ProxyValidator {
|
||||
}
|
||||
}
|
||||
|
||||
/// 验证请求并提取客户端信息
|
||||
/// Validates an incoming request and extracts client information.
|
||||
pub fn validate_request(&self, peer_addr: Option<SocketAddr>, headers: &HeaderMap) -> Result<ClientInfo, ProxyError> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// 记录验证开始
|
||||
// Record the start of the validation attempt.
|
||||
self.record_metric_start();
|
||||
|
||||
// 验证请求
|
||||
// Perform the internal validation logic.
|
||||
let result = self.validate_request_internal(peer_addr, headers);
|
||||
|
||||
// 记录验证结果
|
||||
// Record the result and duration.
|
||||
let duration = start_time.elapsed();
|
||||
self.record_metric_result(&result, duration);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// 内部验证逻辑
|
||||
/// Internal logic for request validation.
|
||||
fn validate_request_internal(&self, peer_addr: Option<SocketAddr>, headers: &HeaderMap) -> Result<ClientInfo, ProxyError> {
|
||||
// 如果没有对端地址,使用默认值
|
||||
// Fallback to unspecified address if peer address is missing.
|
||||
let peer_addr = peer_addr.unwrap_or_else(|| SocketAddr::new(IpAddr::from([0, 0, 0, 0]), 0));
|
||||
|
||||
// 检查是否来自可信代理
|
||||
// Check if the direct peer is a trusted proxy.
|
||||
if self.config.is_trusted(&peer_addr) {
|
||||
debug!("Request from trusted proxy: {}", peer_addr.ip());
|
||||
debug!("Request received from trusted proxy: {}", peer_addr.ip());
|
||||
|
||||
// 来自可信代理,解析转发头部
|
||||
// Parse and validate headers from the trusted proxy.
|
||||
self.validate_trusted_proxy_request(&peer_addr, headers)
|
||||
} else {
|
||||
// 检查是否为私有网络地址
|
||||
// Log a warning if the request is from a private network but not trusted.
|
||||
if self.config.is_private_network(&peer_addr.ip()) {
|
||||
warn!(
|
||||
"Request from private network but not trusted: {}. This might be a configuration issue.",
|
||||
"Request from private network but not trusted: {}. This might indicate a configuration issue.",
|
||||
peer_addr.ip()
|
||||
);
|
||||
}
|
||||
|
||||
// 来自不可信代理或直接连接
|
||||
// Treat as a direct connection if the peer is not trusted.
|
||||
Ok(ClientInfo::direct(peer_addr))
|
||||
}
|
||||
}
|
||||
|
||||
/// 验证来自可信代理的请求
|
||||
/// Validates a request that originated from a trusted proxy.
|
||||
fn validate_trusted_proxy_request(&self, proxy_addr: &SocketAddr, headers: &HeaderMap) -> Result<ClientInfo, ProxyError> {
|
||||
let proxy_ip = proxy_addr.ip();
|
||||
|
||||
// 优先使用 RFC 7239 Forwarded 头部(如果启用)
|
||||
// Prefer RFC 7239 "Forwarded" header if enabled, otherwise fallback to legacy headers.
|
||||
let client_info = if self.config.enable_rfc7239 {
|
||||
self.try_parse_rfc7239_headers(headers, proxy_ip)
|
||||
.unwrap_or_else(|| self.parse_legacy_headers(headers, proxy_ip))
|
||||
@@ -169,28 +168,21 @@ impl ProxyValidator {
|
||||
self.parse_legacy_headers(headers, proxy_ip)
|
||||
};
|
||||
|
||||
// 验证代理链
|
||||
// Analyze the integrity and continuity of the proxy chain.
|
||||
let chain_analysis = self
|
||||
.chain_analyzer
|
||||
.analyze_chain(&client_info.proxy_chain, proxy_ip, headers)?;
|
||||
|
||||
// 检查代理链长度
|
||||
// Enforce maximum hop limit.
|
||||
if chain_analysis.hops > self.config.max_hops {
|
||||
return Err(ProxyError::ChainTooLong(chain_analysis.hops, self.config.max_hops));
|
||||
}
|
||||
|
||||
// 检查链连续性(如果启用)
|
||||
// Enforce chain continuity if enabled.
|
||||
if self.config.enable_chain_continuity_check && !chain_analysis.is_continuous {
|
||||
return Err(ProxyError::ChainNotContinuous);
|
||||
}
|
||||
|
||||
// 创建客户端信息
|
||||
let warnings = if !chain_analysis.warnings.is_empty() {
|
||||
chain_analysis.warnings
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ClientInfo::from_trusted_proxy(
|
||||
chain_analysis.client_ip,
|
||||
client_info.forwarded_host,
|
||||
@@ -198,11 +190,11 @@ impl ProxyValidator {
|
||||
proxy_ip,
|
||||
chain_analysis.hops,
|
||||
self.config.validation_mode,
|
||||
warnings,
|
||||
chain_analysis.warnings,
|
||||
))
|
||||
}
|
||||
|
||||
/// 尝试解析 RFC 7239 Forwarded 头部
|
||||
/// Attempts to parse the RFC 7239 "Forwarded" header.
|
||||
fn try_parse_rfc7239_headers(&self, headers: &HeaderMap, proxy_ip: IpAddr) -> Option<ParsedHeaders> {
|
||||
headers
|
||||
.get("forwarded")
|
||||
@@ -210,7 +202,7 @@ impl ProxyValidator {
|
||||
.and_then(|s| Self::parse_forwarded_header(s, proxy_ip))
|
||||
}
|
||||
|
||||
/// 解析传统的代理头部
|
||||
/// Parses legacy proxy headers (X-Forwarded-For, X-Forwarded-Host, X-Forwarded-Proto).
|
||||
fn parse_legacy_headers(&self, headers: &HeaderMap, proxy_ip: IpAddr) -> ParsedHeaders {
|
||||
let forwarded_host = headers
|
||||
.get("x-forwarded-host")
|
||||
@@ -225,8 +217,8 @@ impl ProxyValidator {
|
||||
let proxy_chain = headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| Self::parse_x_forwarded_for(s))
|
||||
.unwrap_or_else(Vec::new);
|
||||
.map(Self::parse_x_forwarded_for)
|
||||
.unwrap_or_default();
|
||||
|
||||
ParsedHeaders {
|
||||
proxy_chain,
|
||||
@@ -235,16 +227,15 @@ impl ProxyValidator {
|
||||
}
|
||||
}
|
||||
|
||||
/// 解析 RFC 7239 Forwarded 头部
|
||||
/// Parses the RFC 7239 "Forwarded" header value.
|
||||
fn parse_forwarded_header(header_value: &str, proxy_ip: IpAddr) -> Option<ParsedHeaders> {
|
||||
// 简化实现:只处理第一个值
|
||||
// Simplified implementation: processes only the first entry in the header.
|
||||
let first_part = header_value.split(',').next()?.trim();
|
||||
|
||||
let mut proxy_chain = Vec::new();
|
||||
let mut forwarded_host = None;
|
||||
let mut forwarded_proto = None;
|
||||
|
||||
// 解析键值对
|
||||
for part in first_part.split(';') {
|
||||
let part = part.trim();
|
||||
if let Some((key, value)) = part.split_once('=') {
|
||||
@@ -253,7 +244,7 @@ impl ProxyValidator {
|
||||
|
||||
match key.as_str() {
|
||||
"for" => {
|
||||
// 解析客户端 IP(可能包含端口)
|
||||
// Extract IP address, ignoring port if present.
|
||||
if let Some(ip_part) = value.split(':').next() {
|
||||
if let Ok(ip) = ip_part.parse::<IpAddr>() {
|
||||
proxy_chain.push(ip);
|
||||
@@ -271,7 +262,7 @@ impl ProxyValidator {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到客户端 IP,添加代理 IP 作为备选
|
||||
// Fallback to the proxy IP if no client IP was found in the header.
|
||||
if proxy_chain.is_empty() {
|
||||
proxy_chain.push(proxy_ip);
|
||||
}
|
||||
@@ -283,28 +274,28 @@ impl ProxyValidator {
|
||||
})
|
||||
}
|
||||
|
||||
/// 解析 X-Forwarded-For 头部
|
||||
/// Parses the X-Forwarded-For header into a list of IP addresses.
|
||||
fn parse_x_forwarded_for(header_value: &str) -> Vec<IpAddr> {
|
||||
header_value
|
||||
.split(',')
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter_map(|s| {
|
||||
// 移除端口部分(如果存在)
|
||||
// Strip port if present.
|
||||
let ip_part = s.split(':').next().unwrap_or(s);
|
||||
ip_part.parse::<IpAddr>().ok()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 记录验证开始指标
|
||||
/// Records the start of a validation attempt in metrics.
|
||||
fn record_metric_start(&self) {
|
||||
if let Some(metrics) = &self.metrics {
|
||||
metrics.increment_validation_attempts();
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录验证结果指标
|
||||
/// Records the result of a validation attempt in metrics.
|
||||
fn record_metric_result(&self, result: &Result<ClientInfo, ProxyError>, duration: std::time::Duration) {
|
||||
if let Some(metrics) = &self.metrics {
|
||||
match result {
|
||||
@@ -314,7 +305,6 @@ impl ProxyValidator {
|
||||
Err(err) => {
|
||||
metrics.record_validation_failure(err, duration);
|
||||
|
||||
// 记录失败的验证(如果启用)
|
||||
if self.config.log_failed_validations {
|
||||
warn!("Proxy validation failed: {}", err);
|
||||
}
|
||||
@@ -324,13 +314,13 @@ impl ProxyValidator {
|
||||
}
|
||||
}
|
||||
|
||||
/// 解析后的头部信息
|
||||
/// Internal structure for holding parsed header information.
|
||||
#[derive(Debug, Clone)]
|
||||
struct ParsedHeaders {
|
||||
/// 代理链(客户端 IP 在第一个位置)
|
||||
/// The chain of proxy IPs (client IP is typically the first).
|
||||
proxy_chain: Vec<IpAddr>,
|
||||
/// 转发的主机名
|
||||
/// The original host requested.
|
||||
forwarded_host: Option<String>,
|
||||
/// 转发的协议
|
||||
/// The original protocol used.
|
||||
forwarded_proto: Option<String>,
|
||||
}
|
||||
|
||||
@@ -12,14 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Shared application state.
|
||||
|
||||
use crate::{AppConfig, ProxyMetrics};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// 应用状态
|
||||
/// Global application state shared across handlers and middleware.
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
/// 应用配置
|
||||
/// Immutable application configuration.
|
||||
pub config: Arc<AppConfig>,
|
||||
/// 代理指标收集器
|
||||
/// Optional metrics collector for observability.
|
||||
pub metrics: Option<ProxyMetrics>,
|
||||
}
|
||||
|
||||
@@ -12,22 +12,22 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! IP address utility functions
|
||||
//! IP address utility functions for validation and classification.
|
||||
|
||||
use ipnetwork::IpNetwork;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::str::FromStr;
|
||||
|
||||
/// IP 工具函数集合
|
||||
/// Collection of IP-related utility functions.
|
||||
pub struct IpUtils;
|
||||
|
||||
impl IpUtils {
|
||||
/// 检查 IP 地址是否有效
|
||||
/// Checks if an IP address is valid for general use (not unspecified, multicast, or reserved).
|
||||
pub fn is_valid_ip_address(ip: &IpAddr) -> bool {
|
||||
!ip.is_unspecified() && !ip.is_multicast() && !Self::is_reserved_ip(ip)
|
||||
}
|
||||
|
||||
/// 检查 IP 是否为保留地址
|
||||
/// Checks if an IP address belongs to a reserved range.
|
||||
pub fn is_reserved_ip(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => Self::is_reserved_ipv4(ipv4),
|
||||
@@ -35,11 +35,11 @@ impl IpUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查 IPv4 是否为保留地址
|
||||
/// Checks if an IPv4 address belongs to a reserved range.
|
||||
pub fn is_reserved_ipv4(ip: &Ipv4Addr) -> bool {
|
||||
let octets = ip.octets();
|
||||
|
||||
// 检查常见的保留地址范围
|
||||
// Check common reserved IPv4 ranges
|
||||
matches!(
|
||||
octets,
|
||||
[0, _, _, _] | // 0.0.0.0/8
|
||||
@@ -60,11 +60,11 @@ impl IpUtils {
|
||||
)
|
||||
}
|
||||
|
||||
/// 检查 IPv6 是否为保留地址
|
||||
/// Checks if an IPv6 address belongs to a reserved range.
|
||||
pub fn is_reserved_ipv6(ip: &Ipv6Addr) -> bool {
|
||||
let segments = ip.segments();
|
||||
|
||||
// 检查常见的保留地址范围
|
||||
// Check common reserved IPv6 ranges
|
||||
matches!(
|
||||
segments,
|
||||
[0, 0, 0, 0, 0, 0, 0, 0] | // ::/128
|
||||
@@ -76,7 +76,7 @@ impl IpUtils {
|
||||
)
|
||||
}
|
||||
|
||||
/// 检查 IP 是否为私有地址
|
||||
/// Checks if an IP address is a private address.
|
||||
pub fn is_private_ip(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => Self::is_private_ipv4(ipv4),
|
||||
@@ -84,7 +84,7 @@ impl IpUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查 IPv4 是否为私有地址
|
||||
/// Checks if an IPv4 address is a private address.
|
||||
pub fn is_private_ipv4(ip: &Ipv4Addr) -> bool {
|
||||
let octets = ip.octets();
|
||||
|
||||
@@ -96,7 +96,7 @@ impl IpUtils {
|
||||
)
|
||||
}
|
||||
|
||||
/// 检查 IPv6 是否为私有地址
|
||||
/// Checks if an IPv6 address is a private address.
|
||||
pub fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
|
||||
let segments = ip.segments();
|
||||
|
||||
@@ -106,7 +106,7 @@ impl IpUtils {
|
||||
)
|
||||
}
|
||||
|
||||
/// 检查 IP 是否为回环地址
|
||||
/// Checks if an IP address is a loopback address.
|
||||
pub fn is_loopback_ip(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => ipv4.is_loopback(),
|
||||
@@ -114,7 +114,7 @@ impl IpUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查 IP 是否为链路本地地址
|
||||
/// Checks if an IP address is a link-local address.
|
||||
pub fn is_link_local_ip(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => ipv4.is_link_local(),
|
||||
@@ -122,7 +122,7 @@ impl IpUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查 IP 是否为文档地址(TEST-NET)
|
||||
/// Checks if an IP address is a documentation address (TEST-NET).
|
||||
pub fn is_documentation_ip(ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => {
|
||||
@@ -141,12 +141,12 @@ impl IpUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 从字符串解析 IP 地址,支持 CIDR 表示法
|
||||
/// Parses an IP address or CIDR range from a string.
|
||||
pub fn parse_ip_or_cidr(s: &str) -> Result<IpNetwork, String> {
|
||||
IpNetwork::from_str(s).map_err(|e| format!("Failed to parse IP/CIDR '{}': {}", s, e))
|
||||
}
|
||||
|
||||
/// 从逗号分隔的字符串解析 IP 列表
|
||||
/// Parses a comma-separated list of IP addresses.
|
||||
pub fn parse_ip_list(s: &str) -> Result<Vec<IpAddr>, String> {
|
||||
let mut ips = Vec::new();
|
||||
|
||||
@@ -165,7 +165,7 @@ impl IpUtils {
|
||||
Ok(ips)
|
||||
}
|
||||
|
||||
/// 从逗号分隔的字符串解析网络列表
|
||||
/// Parses a comma-separated list of IP networks (CIDR).
|
||||
pub fn parse_network_list(s: &str) -> Result<Vec<IpNetwork>, String> {
|
||||
let mut networks = Vec::new();
|
||||
|
||||
@@ -184,12 +184,12 @@ impl IpUtils {
|
||||
Ok(networks)
|
||||
}
|
||||
|
||||
/// 检查 IP 是否在给定的网络列表中
|
||||
/// Checks if an IP address is contained within any of the given networks.
|
||||
pub fn ip_in_networks(ip: &IpAddr, networks: &[IpNetwork]) -> bool {
|
||||
networks.iter().any(|network| network.contains(*ip))
|
||||
}
|
||||
|
||||
/// 获取 IP 地址的类型描述
|
||||
/// Returns a string description of the IP address type.
|
||||
pub fn get_ip_type(ip: &IpAddr) -> &'static str {
|
||||
if Self::is_private_ip(ip) {
|
||||
"private"
|
||||
@@ -206,16 +206,16 @@ impl IpUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 将 IP 地址转换为规范形式
|
||||
/// Returns the canonical string representation of an IP address.
|
||||
pub fn canonical_ip(ip: &IpAddr) -> String {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => ipv4.to_string(),
|
||||
IpAddr::V6(ipv6) => {
|
||||
// 压缩 IPv6 地址
|
||||
// Compress IPv6 address
|
||||
let mut result = String::new();
|
||||
let segments = ipv6.segments();
|
||||
|
||||
// 查找最长的连续零段
|
||||
// Find the longest sequence of zero segments
|
||||
let mut longest_start = 0;
|
||||
let mut longest_len = 0;
|
||||
let mut current_start = 0;
|
||||
@@ -241,20 +241,21 @@ impl IpUtils {
|
||||
longest_len = current_len;
|
||||
}
|
||||
|
||||
// 格式化为字符串
|
||||
for mut i in 0..8 {
|
||||
// Format as string
|
||||
let mut i = 0;
|
||||
while i < 8 {
|
||||
if i == longest_start && longest_len > 1 {
|
||||
result.push_str("::");
|
||||
i += longest_len - 1;
|
||||
} else if i == longest_start && longest_len == 1 {
|
||||
result.push('0');
|
||||
i += longest_len;
|
||||
if i == 8 {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
if i > 0 && i != longest_start {
|
||||
if i > 0 && (i != longest_start + longest_len || longest_len <= 1) {
|
||||
result.push(':');
|
||||
}
|
||||
if segments[i] != 0 || (i == 7 && result.is_empty()) {
|
||||
result.push_str(&format!("{:x}", segments[i]));
|
||||
}
|
||||
result.push_str(&format!("{:x}", segments[i]));
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,3 +264,8 @@ impl IpUtils {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to check if an IP address is valid.
|
||||
pub fn is_valid_ip_address(ip: &IpAddr) -> bool {
|
||||
IpUtils::is_valid_ip_address(ip)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Utility functions and helpers
|
||||
//! Utility functions and helpers for the trusted proxy system.
|
||||
|
||||
mod ip;
|
||||
mod validation;
|
||||
@@ -20,32 +20,32 @@ mod validation;
|
||||
pub use ip::*;
|
||||
pub use validation::*;
|
||||
|
||||
/// 工具函数集合
|
||||
/// Collection of general utility functions.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Utils;
|
||||
|
||||
impl Utils {
|
||||
/// 生成追踪 ID
|
||||
/// Generates a unique trace ID.
|
||||
pub fn generate_trace_id() -> String {
|
||||
format!("trace-{}", uuid::Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// 生成 Span ID
|
||||
/// Generates a unique span ID.
|
||||
pub fn generate_span_id() -> String {
|
||||
format!("span-{}", uuid::Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// 安全的将字符串转换为 usize
|
||||
/// Safely parses a string into a `usize`, returning a default value on failure.
|
||||
pub fn safe_parse_usize(s: &str, default: usize) -> usize {
|
||||
s.parse().unwrap_or(default)
|
||||
}
|
||||
|
||||
/// 安全的将字符串转换为 u64
|
||||
/// Safely parses a string into a `u64`, returning a default value on failure.
|
||||
pub fn safe_parse_u64(s: &str, default: u64) -> u64 {
|
||||
s.parse().unwrap_or(default)
|
||||
}
|
||||
|
||||
/// 安全的将字符串转换为布尔值
|
||||
/// Safely parses a string into a boolean, returning a default value on failure.
|
||||
pub fn safe_parse_bool(s: &str, default: bool) -> bool {
|
||||
match s.to_lowercase().as_str() {
|
||||
"true" | "1" | "yes" | "on" => true,
|
||||
@@ -54,7 +54,7 @@ impl Utils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 格式化持续时间
|
||||
/// Formats a `Duration` into a human-readable string.
|
||||
pub fn format_duration(duration: std::time::Duration) -> String {
|
||||
if duration.as_secs() > 0 {
|
||||
format!("{:.2}s", duration.as_secs_f64())
|
||||
@@ -67,22 +67,22 @@ impl Utils {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取当前时间戳
|
||||
/// Returns the current UTC timestamp in RFC 3339 format.
|
||||
pub fn current_timestamp() -> String {
|
||||
chrono::Utc::now().to_rfc3339()
|
||||
}
|
||||
|
||||
/// 安全的获取环境变量
|
||||
/// Safely retrieves an environment variable.
|
||||
pub fn get_env_var(key: &str) -> Option<String> {
|
||||
std::env::var(key).ok()
|
||||
}
|
||||
|
||||
/// 获取环境变量,如果不存在则使用默认值
|
||||
/// Retrieves an environment variable or returns a default value if not set.
|
||||
pub fn get_env_var_or(key: &str, default: &str) -> String {
|
||||
std::env::var(key).unwrap_or_else(|_| default.to_string())
|
||||
}
|
||||
|
||||
/// 检查环境变量是否存在
|
||||
/// Checks if an environment variable is set.
|
||||
pub fn has_env_var(key: &str) -> bool {
|
||||
std::env::var(key).is_ok()
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Validation utility functions
|
||||
//! Validation utility functions for various data types.
|
||||
|
||||
use http::HeaderMap;
|
||||
use lazy_static::lazy_static;
|
||||
@@ -20,11 +20,11 @@ use regex::Regex;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// 验证工具函数集合
|
||||
/// Collection of validation utility functions.
|
||||
pub struct ValidationUtils;
|
||||
|
||||
impl ValidationUtils {
|
||||
/// 验证电子邮件地址
|
||||
/// Validates an email address format.
|
||||
pub fn is_valid_email(email: &str) -> bool {
|
||||
lazy_static! {
|
||||
static ref EMAIL_REGEX: Regex =
|
||||
@@ -34,7 +34,7 @@ impl ValidationUtils {
|
||||
EMAIL_REGEX.is_match(email)
|
||||
}
|
||||
|
||||
/// 验证 URL
|
||||
/// Validates a URL format.
|
||||
pub fn is_valid_url(url: &str) -> bool {
|
||||
lazy_static! {
|
||||
static ref URL_REGEX: Regex =
|
||||
@@ -44,22 +44,19 @@ impl ValidationUtils {
|
||||
URL_REGEX.is_match(url)
|
||||
}
|
||||
|
||||
/// 验证 X-Forwarded-For 头部
|
||||
/// Validates the format of an X-Forwarded-For header value.
|
||||
pub fn validate_x_forwarded_for(header_value: &str) -> bool {
|
||||
if header_value.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 分割 IP 地址
|
||||
let ips: Vec<&str> = header_value.split(',').map(|s| s.trim()).collect();
|
||||
|
||||
// 检查每个 IP 地址
|
||||
for ip_str in ips {
|
||||
if ip_str.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 移除端口部分(如果存在)
|
||||
let ip_part = ip_str.split(':').next().unwrap_or(ip_str);
|
||||
|
||||
if IpAddr::from_str(ip_part).is_err() {
|
||||
@@ -70,20 +67,18 @@ impl ValidationUtils {
|
||||
true
|
||||
}
|
||||
|
||||
/// 验证 Forwarded 头部(RFC 7239)
|
||||
/// Validates the format of an RFC 7239 Forwarded header value.
|
||||
pub fn validate_forwarded_header(header_value: &str) -> bool {
|
||||
if header_value.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 简化的验证:检查基本格式
|
||||
let parts: Vec<&str> = header_value.split(';').collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查每个部分是否包含等号
|
||||
for part in parts {
|
||||
let part = part.trim();
|
||||
if !part.contains('=') {
|
||||
@@ -94,7 +89,7 @@ impl ValidationUtils {
|
||||
true
|
||||
}
|
||||
|
||||
/// 验证 IP 地址是否在允许的范围内
|
||||
/// Checks if an IP address is within any of the specified CIDR ranges.
|
||||
pub fn validate_ip_in_range(ip: &IpAddr, cidr_ranges: &[String]) -> bool {
|
||||
for cidr in cidr_ranges {
|
||||
if let Ok(network) = ipnetwork::IpNetwork::from_str(cidr) {
|
||||
@@ -107,16 +102,14 @@ impl ValidationUtils {
|
||||
false
|
||||
}
|
||||
|
||||
/// 验证头部是否包含恶意内容
|
||||
/// Validates a header value for security (length and control characters).
|
||||
pub fn validate_header_value(value: &str) -> bool {
|
||||
// 检查是否包含控制字符(除了水平制表符)
|
||||
for c in value.chars() {
|
||||
if c.is_control() && c != '\t' && c != '\n' && c != '\r' {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 检查长度限制(防止头部过大攻击)
|
||||
if value.len() > 8192 {
|
||||
return false;
|
||||
}
|
||||
@@ -124,53 +117,47 @@ impl ValidationUtils {
|
||||
true
|
||||
}
|
||||
|
||||
/// 验证整个头部映射
|
||||
/// Validates an entire HeaderMap for security.
|
||||
pub fn validate_headers(headers: &HeaderMap) -> bool {
|
||||
for (name, value) in headers {
|
||||
// 检查头部名称
|
||||
let name_str = name.as_str();
|
||||
if name_str.len() > 256 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查头部值
|
||||
if let Ok(value_str) = value.to_str() {
|
||||
if !Self::validate_header_value(value_str) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// 无法转换为字符串,可能包含二进制数据
|
||||
if value.len() > 8192 {
|
||||
return false;
|
||||
}
|
||||
} else if value.len() > 8192 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// 验证端口号
|
||||
/// Validates a port number.
|
||||
pub fn validate_port(port: u16) -> bool {
|
||||
port > 0 && port <= 65535
|
||||
port > 0
|
||||
}
|
||||
|
||||
/// 验证 CIDR 表示法
|
||||
/// Validates a CIDR notation string.
|
||||
pub fn validate_cidr(cidr: &str) -> bool {
|
||||
ipnetwork::IpNetwork::from_str(cidr).is_ok()
|
||||
}
|
||||
|
||||
/// 验证代理链长度
|
||||
/// Validates the length of a proxy chain.
|
||||
pub fn validate_proxy_chain_length(chain: &[IpAddr], max_length: usize) -> bool {
|
||||
chain.len() <= max_length
|
||||
}
|
||||
|
||||
/// 验证代理链是否连续
|
||||
/// Validates that a proxy chain does not contain duplicate adjacent IPs.
|
||||
pub fn validate_proxy_chain_continuity(chain: &[IpAddr]) -> bool {
|
||||
if chain.len() < 2 {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 检查是否有重复的相邻 IP
|
||||
for i in 1..chain.len() {
|
||||
if chain[i] == chain[i - 1] {
|
||||
return false;
|
||||
@@ -180,24 +167,25 @@ impl ValidationUtils {
|
||||
true
|
||||
}
|
||||
|
||||
/// 验证字符串是否只包含安全字符
|
||||
/// Checks if a string contains only safe characters for use in URLs or headers.
|
||||
pub fn is_safe_string(s: &str) -> bool {
|
||||
// 允许的字符:字母、数字、基本标点符号
|
||||
let safe_pattern = Regex::new(r"^[a-zA-Z0-9\-._~:/?#\[\]@!$&'()*+,;=]+$").unwrap();
|
||||
safe_pattern.is_match(s)
|
||||
lazy_static! {
|
||||
static ref SAFE_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9\-._~:/?#\[\]@!$&'()*+,;=]+$").unwrap();
|
||||
}
|
||||
SAFE_REGEX.is_match(s)
|
||||
}
|
||||
|
||||
/// 验证速率限制参数
|
||||
/// Validates rate limiting parameters.
|
||||
pub fn validate_rate_limit_params(requests: u32, period_seconds: u64) -> bool {
|
||||
requests > 0 && requests <= 10000 && period_seconds > 0 && period_seconds <= 86400
|
||||
}
|
||||
|
||||
/// 验证缓存参数
|
||||
/// Validates cache configuration parameters.
|
||||
pub fn validate_cache_params(capacity: usize, ttl_seconds: u64) -> bool {
|
||||
capacity > 0 && capacity <= 1000000 && ttl_seconds > 0 && ttl_seconds <= 86400
|
||||
}
|
||||
|
||||
/// 脱敏敏感数据
|
||||
/// Redacts sensitive information from a string based on provided patterns.
|
||||
pub fn mask_sensitive_data(data: &str, sensitive_patterns: &[&str]) -> String {
|
||||
let mut result = data.to_string();
|
||||
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! API integration tests
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::{extract::State, routing::get, Router};
|
||||
use serde_json::{json, Value};
|
||||
use tower::ServiceExt;
|
||||
|
||||
use crate::config::{AppConfig, TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
use crate::middleware::TrustedProxyLayer;
|
||||
use crate::AppState;
|
||||
|
||||
fn create_test_app_state() -> AppState {
|
||||
let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())];
|
||||
|
||||
let proxy_config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]);
|
||||
|
||||
let config = AppConfig::new(
|
||||
proxy_config,
|
||||
crate::config::CacheConfig::default(),
|
||||
crate::config::MonitoringConfig::default(),
|
||||
crate::config::CloudConfig::default(),
|
||||
"127.0.0.1:3000".parse().unwrap(),
|
||||
);
|
||||
|
||||
AppState {
|
||||
config: Arc::new(config),
|
||||
metrics: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_api_router() -> Router {
|
||||
let state = create_test_app_state();
|
||||
let proxy_layer = TrustedProxyLayer::enabled(state.config.proxy.clone(), None);
|
||||
|
||||
Router::new()
|
||||
.route("/health", get(health_check))
|
||||
.route("/config", get(show_config))
|
||||
.with_state(state)
|
||||
.layer(proxy_layer)
|
||||
}
|
||||
|
||||
async fn health_check() -> axum::response::Json<Value> {
|
||||
axum::response::Json(json!({
|
||||
"status": "healthy",
|
||||
"service": "trusted-proxy-test"
|
||||
}))
|
||||
}
|
||||
|
||||
async fn show_config(State(state): State<AppState>) -> axum::response::Json<Value> {
|
||||
axum::response::Json(json!({
|
||||
"server": state.config.server_addr.to_string(),
|
||||
"proxy": {
|
||||
"trusted_networks": state.config.proxy.proxies.len(),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_check_endpoint() {
|
||||
let app = create_test_api_router();
|
||||
|
||||
let request = axum::http::Request::builder().uri("/health").body(Body::empty()).unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
let json: Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
assert_eq!(json["status"], "healthy");
|
||||
assert_eq!(json["service"], "trusted-proxy-test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_endpoint() {
|
||||
let app = create_test_api_router();
|
||||
|
||||
let request = axum::http::Request::builder().uri("/config").body(Body::empty()).unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
let json: Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
assert_eq!(json["server"], "127.0.0.1:3000");
|
||||
assert_eq!(json["proxy"]["trusted_networks"], 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_proxy_headers_in_api() {
|
||||
let state = create_test_app_state();
|
||||
let proxy_layer = TrustedProxyLayer::enabled(state.config.proxy.clone(), None);
|
||||
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/client-test",
|
||||
get(|req: axum::extract::Request| async move {
|
||||
let client_info = req.extensions().get::<crate::middleware::ClientInfo>();
|
||||
match client_info {
|
||||
Some(info) => axum::response::Json(json!({
|
||||
"client_ip": info.real_ip.to_string(),
|
||||
"trusted": info.is_from_trusted_proxy
|
||||
})),
|
||||
None => axum::response::Json(json!({
|
||||
"error": "No client info"
|
||||
})),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.with_state(state)
|
||||
.layer(proxy_layer);
|
||||
|
||||
// 测试带代理头部的请求
|
||||
let request = axum::http::Request::builder()
|
||||
.uri("/client-test")
|
||||
.header("X-Forwarded-For", "203.0.113.195")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
let json: Value = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
// 由于请求来自 127.0.0.1(可信代理),应该解析 X-Forwarded-For
|
||||
if json.get("client_ip").is_some() {
|
||||
let client_ip = json["client_ip"].as_str().unwrap();
|
||||
// 可能是 203.0.113.195 或 127.0.0.1,取决于中间件如何配置
|
||||
assert!(client_ip == "203.0.113.195" || client_ip == "127.0.0.1");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_missing_endpoint() {
|
||||
let app = create_test_api_router();
|
||||
|
||||
let request = axum::http::Request::builder().uri("/not-found").body(Body::empty()).unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
|
||||
// 应该返回 404
|
||||
assert_eq!(response.status(), 404);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_request_without_proxy_layer() {
|
||||
// 创建没有代理中间件的路由
|
||||
let app = Router::new().route("/simple", get(|| async { "OK" }));
|
||||
|
||||
let request = axum::http::Request::builder().uri("/simple").body(Body::empty()).unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
assert_eq!(String::from_utf8(body.to_vec()).unwrap(), "OK");
|
||||
}
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Cloud metadata integration tests
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
use crate::cloud::detector::CloudDetector;
|
||||
use crate::cloud::metadata::{AwsMetadataFetcher, AzureMetadataFetcher, GcpMetadataFetcher};
|
||||
use crate::cloud::ranges::{CloudflareIpRanges, GoogleCloudIpRanges};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cloud_detector_disabled() {
|
||||
let detector = CloudDetector::new(
|
||||
false, // 禁用
|
||||
std::time::Duration::from_secs(1),
|
||||
None,
|
||||
);
|
||||
|
||||
let provider = detector.detect_provider();
|
||||
assert!(provider.is_none());
|
||||
|
||||
let ranges = detector.fetch_trusted_ranges().await;
|
||||
assert!(ranges.is_ok());
|
||||
assert!(ranges.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cloud_detector_forced_provider() {
|
||||
let detector = CloudDetector::new(true, std::time::Duration::from_secs(1), Some("aws".to_string()));
|
||||
|
||||
let provider = detector.detect_provider();
|
||||
assert!(provider.is_some());
|
||||
assert_eq!(provider.unwrap().name(), "aws");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aws_metadata_fetcher() {
|
||||
let fetcher = AwsMetadataFetcher::new();
|
||||
|
||||
// 测试提供者名称
|
||||
assert_eq!(fetcher.provider_name(), "aws");
|
||||
|
||||
// 由于不在 AWS 环境中,这些调用应该失败或返回默认值
|
||||
let network_result = fetcher.fetch_network_cidrs().await;
|
||||
// 可能返回默认范围或错误
|
||||
assert!(network_result.is_ok());
|
||||
|
||||
let public_result = fetcher.fetch_public_ip_ranges().await;
|
||||
// 可能从 API 获取或返回空列表
|
||||
assert!(public_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_azure_metadata_fetcher() {
|
||||
let fetcher = AzureMetadataFetcher::new();
|
||||
|
||||
// 测试提供者名称
|
||||
assert_eq!(fetcher.provider_name(), "azure");
|
||||
|
||||
// 由于不在 Azure 环境中,这些调用应该返回默认值
|
||||
let network_result = fetcher.fetch_network_cidrs().await;
|
||||
assert!(network_result.is_ok());
|
||||
|
||||
let public_result = fetcher.fetch_public_ip_ranges().await;
|
||||
assert!(public_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gcp_metadata_fetcher() {
|
||||
let fetcher = GcpMetadataFetcher::new();
|
||||
|
||||
// 测试提供者名称
|
||||
assert_eq!(fetcher.provider_name(), "gcp");
|
||||
|
||||
// 由于不在 GCP 环境中,这些调用应该返回默认值
|
||||
let network_result = fetcher.fetch_network_cidrs().await;
|
||||
assert!(network_result.is_ok());
|
||||
|
||||
let public_result = fetcher.fetch_public_ip_ranges().await;
|
||||
assert!(public_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cloudflare_ip_ranges_static() {
|
||||
let ranges = CloudflareIpRanges::fetch().await;
|
||||
|
||||
assert!(ranges.is_ok());
|
||||
|
||||
let networks = ranges.unwrap();
|
||||
assert!(!networks.is_empty());
|
||||
|
||||
// 检查是否包含预期的范围
|
||||
let has_ipv4 = networks
|
||||
.iter()
|
||||
.any(|n| n.to_string().contains("103.21.244.0/22") || n.to_string().contains("198.41.128.0/17"));
|
||||
|
||||
let has_ipv6 = networks
|
||||
.iter()
|
||||
.any(|n| n.to_string().contains("2400:cb00::/32") || n.to_string().contains("2606:4700::/32"));
|
||||
|
||||
assert!(has_ipv4 || has_ipv6);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_google_cloud_ip_ranges_api_mock() {
|
||||
// 创建模拟服务器
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
// 模拟 Google IP 范围 API 响应
|
||||
let mock_response = r#"
|
||||
{
|
||||
"prefixes": [
|
||||
{"ipv4Prefix": "8.34.208.0/20"},
|
||||
{"ipv4Prefix": "8.35.192.0/20"},
|
||||
{"ipv6Prefix": "2001:4860::/32"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/ipranges/cloud.json"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string(mock_response))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
// 创建自定义客户端指向模拟服务器
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(2))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let url = format!("{}/ipranges/cloud.json", mock_server.uri());
|
||||
|
||||
let response = client.get(&url).send().await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = response.text().await.unwrap();
|
||||
assert!(body.contains("8.34.208.0/20"));
|
||||
assert!(body.contains("2001:4860::/32"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cloud_detector_try_all_providers() {
|
||||
let detector = CloudDetector::new(true, std::time::Duration::from_secs(2), None);
|
||||
|
||||
// 在测试环境中,所有提供者都应该失败或返回空列表
|
||||
let result = detector.try_all_providers().await;
|
||||
|
||||
// 应该成功返回(即使是空列表)
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_network_parsing() {
|
||||
// 测试 CIDR 解析
|
||||
let cidr = ipnetwork::IpNetwork::from_str("192.168.1.0/24");
|
||||
assert!(cidr.is_ok());
|
||||
|
||||
let network = cidr.unwrap();
|
||||
assert_eq!(network.prefix(), 24);
|
||||
|
||||
// 测试 IP 包含检查
|
||||
let ip: std::net::IpAddr = "192.168.1.100".parse().unwrap();
|
||||
assert!(network.contains(ip));
|
||||
|
||||
let ip_outside: std::net::IpAddr = "192.168.2.100".parse().unwrap();
|
||||
assert!(!network.contains(ip_outside));
|
||||
|
||||
// 测试 IPv6 CIDR
|
||||
let ipv6_cidr = ipnetwork::IpNetwork::from_str("2001:db8::/32");
|
||||
assert!(ipv6_cidr.is_ok());
|
||||
|
||||
let ipv6_network = ipv6_cidr.unwrap();
|
||||
assert_eq!(ipv6_network.prefix(), 32);
|
||||
}
|
||||
}
|
||||
@@ -1,188 +0,0 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Proxy system integration tests
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::error::Request;
|
||||
use axum::body::Body;
|
||||
use axum::{extract::Request, routing::get, Router};
|
||||
use tower::ServiceExt;
|
||||
|
||||
use crate::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
use crate::middleware::{ClientInfo, TrustedProxyLayer};
|
||||
|
||||
fn create_test_router() -> Router {
|
||||
let proxies = vec![
|
||||
TrustedProxy::Single("127.0.0.1".parse().unwrap()),
|
||||
TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()),
|
||||
];
|
||||
|
||||
let config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]);
|
||||
|
||||
let proxy_layer = TrustedProxyLayer::enabled(config, None);
|
||||
|
||||
Router::new()
|
||||
.route(
|
||||
"/test",
|
||||
get(|req: Request| async move {
|
||||
let client_info = req.extensions().get::<ClientInfo>();
|
||||
match client_info {
|
||||
Some(info) => {
|
||||
format!("IP: {}, Trusted: {}, Hops: {}", info.real_ip, info.is_from_trusted_proxy, info.proxy_hops)
|
||||
}
|
||||
None => "No client info".to_string(),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.layer(proxy_layer)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_direct_connection() {
|
||||
let app = create_test_router();
|
||||
|
||||
// 模拟直接连接(无代理头部)
|
||||
let request = axum::http::Request::builder().uri("/test").body(Body::empty()).unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
let body_str = String::from_utf8(body.to_vec()).unwrap();
|
||||
|
||||
// 应该显示直接连接的 IP(在测试环境中可能是 0.0.0.0)
|
||||
assert!(body_str.contains("IP:"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_trusted_proxy_with_xff() {
|
||||
let app = create_test_router();
|
||||
|
||||
// 模拟来自可信代理的请求
|
||||
let request = axum::http::Request::builder()
|
||||
.uri("/test")
|
||||
.header("X-Forwarded-For", "203.0.113.195, 10.0.1.100")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
let body_str = String::from_utf8(body.to_vec()).unwrap();
|
||||
|
||||
// 应该显示客户端 IP (203.0.113.195)
|
||||
assert!(body_str.contains("203.0.113.195"));
|
||||
assert!(body_str.contains("Trusted: true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_untrusted_proxy_with_xff() {
|
||||
let app = create_test_router();
|
||||
|
||||
// 模拟来自不可信代理的请求
|
||||
let request = axum::http::Request::builder()
|
||||
.uri("/test")
|
||||
.header("X-Forwarded-For", "203.0.113.195")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
let body_str = String::from_utf8(body.to_vec()).unwrap();
|
||||
|
||||
// 由于请求不是来自可信代理,X-Forwarded-For 应该被忽略
|
||||
// 应该显示直接连接的 IP
|
||||
assert!(!body_str.contains("203.0.113.195"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_proxy_chain_too_long() {
|
||||
let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())];
|
||||
|
||||
let config = TrustedProxyConfig::new(
|
||||
proxies,
|
||||
ValidationMode::Strict,
|
||||
true,
|
||||
3, // 最大 3 跳
|
||||
true,
|
||||
vec![],
|
||||
);
|
||||
|
||||
let proxy_layer = TrustedProxyLayer::enabled(config, None);
|
||||
|
||||
let app = Router::new().route("/test", get(|| async { "OK" })).layer(proxy_layer);
|
||||
|
||||
// 模拟超长代理链
|
||||
let xff_value = (0..5).map(|i| format!("10.0.{}.1", i)).collect::<Vec<_>>().join(", ");
|
||||
|
||||
let request = axum::http::Request::builder()
|
||||
.uri("/test")
|
||||
.header("X-Forwarded-For", xff_value)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
|
||||
// 由于代理链太长,验证应该失败
|
||||
// 注意:中间件可能会降级处理,而不是直接拒绝
|
||||
assert_eq!(response.status(), 200);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rfc7239_forwarded_header() {
|
||||
let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())];
|
||||
|
||||
let config = TrustedProxyConfig::new(
|
||||
proxies,
|
||||
ValidationMode::HopByHop,
|
||||
true, // 启用 RFC 7239
|
||||
10,
|
||||
true,
|
||||
vec![],
|
||||
);
|
||||
|
||||
let proxy_layer = TrustedProxyLayer::enabled(config, None);
|
||||
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/test",
|
||||
get(|req: Request| async move {
|
||||
let client_info = req.extensions().get::<ClientInfo>().unwrap();
|
||||
format!("IP: {}", client_info.real_ip)
|
||||
}),
|
||||
)
|
||||
.layer(proxy_layer);
|
||||
|
||||
// 模拟使用 RFC 7239 Forwarded 头部的请求
|
||||
let request = axum::http::Request::builder()
|
||||
.uri("/test")
|
||||
.header("Forwarded", r#"for=192.0.2.60;proto=https;by=203.0.113.43"#)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
let body_str = String::from_utf8(body.to_vec()).unwrap();
|
||||
|
||||
// 应该解析 RFC 7239 头部
|
||||
assert!(body_str.contains("192.0.2.60"));
|
||||
}
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Configuration module unit tests
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::config::env::{DEFAULT_TRUSTED_PROXIES, ENV_TRUSTED_PROXIES};
|
||||
use crate::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
|
||||
#[test]
|
||||
fn test_config_loader_default() {
|
||||
// 清理环境变量
|
||||
std::env::remove_var(ENV_TRUSTED_PROXIES);
|
||||
|
||||
let config = ConfigLoader::from_env_or_default();
|
||||
|
||||
// 验证默认值
|
||||
assert_eq!(config.server_addr.port(), 3000);
|
||||
assert!(!config.proxy.proxies.is_empty());
|
||||
assert_eq!(config.proxy.validation_mode, ValidationMode::HopByHop);
|
||||
assert!(config.proxy.enable_rfc7239);
|
||||
assert_eq!(config.proxy.max_hops, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_loader_env_vars() {
|
||||
// 设置环境变量
|
||||
std::env::set_var(ENV_TRUSTED_PROXIES, "192.168.1.0/24,10.0.0.0/8");
|
||||
std::env::set_var("TRUSTED_PROXY_VALIDATION_MODE", "strict");
|
||||
std::env::set_var("TRUSTED_PROXY_MAX_HOPS", "5");
|
||||
std::env::set_var("SERVER_PORT", "8080");
|
||||
|
||||
let config = ConfigLoader::from_env();
|
||||
|
||||
if let Ok(config) = config {
|
||||
assert_eq!(config.server_addr.port(), 8080);
|
||||
assert_eq!(config.proxy.validation_mode, ValidationMode::Strict);
|
||||
assert_eq!(config.proxy.max_hops, 5);
|
||||
|
||||
// 清理环境变量
|
||||
std::env::remove_var(ENV_TRUSTED_PROXIES);
|
||||
std::env::remove_var("TRUSTED_PROXY_VALIDATION_MODE");
|
||||
std::env::remove_var("TRUSTED_PROXY_MAX_HOPS");
|
||||
std::env::remove_var("SERVER_PORT");
|
||||
} else {
|
||||
panic!("Failed to load config from env");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trusted_proxy_config() {
|
||||
let proxies = vec![
|
||||
TrustedProxy::Single("192.168.1.1".parse().unwrap()),
|
||||
TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()),
|
||||
];
|
||||
|
||||
let config = TrustedProxyConfig::new(proxies.clone(), ValidationMode::Strict, true, 10, true, vec![]);
|
||||
|
||||
assert_eq!(config.proxies.len(), 2);
|
||||
assert_eq!(config.validation_mode, ValidationMode::Strict);
|
||||
assert!(config.enable_rfc7239);
|
||||
assert_eq!(config.max_hops, 10);
|
||||
assert!(config.enable_chain_continuity_check);
|
||||
|
||||
// 测试 IP 检查
|
||||
let test_ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
let test_socket_addr = std::net::SocketAddr::new(test_ip, 8080);
|
||||
|
||||
assert!(config.is_trusted(&test_socket_addr));
|
||||
|
||||
let test_ip2: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
let test_socket_addr2 = std::net::SocketAddr::new(test_ip2, 8080);
|
||||
|
||||
assert!(config.is_trusted(&test_socket_addr2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validation_mode_from_str() {
|
||||
assert_eq!(ValidationMode::from_str("lenient").unwrap(), ValidationMode::Lenient);
|
||||
assert_eq!(ValidationMode::from_str("strict").unwrap(), ValidationMode::Strict);
|
||||
assert_eq!(ValidationMode::from_str("hop_by_hop").unwrap(), ValidationMode::HopByHop);
|
||||
|
||||
// 测试无效值
|
||||
assert!(ValidationMode::from_str("invalid").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trusted_proxy_contains() {
|
||||
// 测试单个 IP
|
||||
let single_proxy = TrustedProxy::Single("192.168.1.1".parse().unwrap());
|
||||
let test_ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
let test_ip2: IpAddr = "192.168.1.2".parse().unwrap();
|
||||
|
||||
assert!(single_proxy.contains(&test_ip));
|
||||
assert!(!single_proxy.contains(&test_ip2));
|
||||
|
||||
// 测试 CIDR 范围
|
||||
let cidr_proxy = TrustedProxy::Cidr("192.168.1.0/24".parse().unwrap());
|
||||
|
||||
assert!(cidr_proxy.contains(&test_ip));
|
||||
assert!(cidr_proxy.contains(&test_ip2));
|
||||
|
||||
let test_ip3: IpAddr = "192.168.2.1".parse().unwrap();
|
||||
assert!(!cidr_proxy.contains(&test_ip3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_private_network_check() {
|
||||
let config = TrustedProxyConfig::new(
|
||||
Vec::new(),
|
||||
ValidationMode::Lenient,
|
||||
true,
|
||||
10,
|
||||
true,
|
||||
vec!["10.0.0.0/8".parse().unwrap(), "192.168.0.0/16".parse().unwrap()],
|
||||
);
|
||||
|
||||
let private_ip: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
let private_ip2: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
let public_ip: IpAddr = "8.8.8.8".parse().unwrap();
|
||||
|
||||
assert!(config.is_private_network(&private_ip));
|
||||
assert!(config.is_private_network(&private_ip2));
|
||||
assert!(!config.is_private_network(&public_ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_ip_list_from_env() {
|
||||
use crate::config::env::parse_ip_list_from_env;
|
||||
|
||||
// 测试有效的 IP 列表
|
||||
std::env::set_var("TEST_IP_LIST", "10.0.0.0/8,192.168.1.0/24");
|
||||
|
||||
let result = parse_ip_list_from_env("TEST_IP_LIST", "");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let networks = result.unwrap();
|
||||
assert_eq!(networks.len(), 2);
|
||||
|
||||
// 测试空值
|
||||
std::env::set_var("TEST_IP_LIST_EMPTY", "");
|
||||
let result = parse_ip_list_from_env("TEST_IP_LIST_EMPTY", "");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
|
||||
// 测试无效值
|
||||
std::env::set_var("TEST_IP_LIST_INVALID", "invalid,10.0.0.0/8");
|
||||
let result = parse_ip_list_from_env("TEST_IP_LIST_INVALID", "");
|
||||
assert!(result.is_ok()); // 无效项会被跳过
|
||||
|
||||
// 清理环境变量
|
||||
std::env::remove_var("TEST_IP_LIST");
|
||||
std::env::remove_var("TEST_IP_LIST_EMPTY");
|
||||
std::env::remove_var("TEST_IP_LIST_INVALID");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_values() {
|
||||
use crate::config::env::{
|
||||
DEFAULT_PROXY_ENABLE_RFC7239, DEFAULT_PROXY_MAX_HOPS, DEFAULT_PROXY_VALIDATION_MODE, DEFAULT_TRUSTED_PROXIES,
|
||||
};
|
||||
|
||||
assert_eq!(DEFAULT_TRUSTED_PROXIES, "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8");
|
||||
assert_eq!(DEFAULT_PROXY_VALIDATION_MODE, "hop_by_hop");
|
||||
assert_eq!(DEFAULT_PROXY_MAX_HOPS, 10);
|
||||
assert!(DEFAULT_PROXY_ENABLE_RFC7239);
|
||||
}
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! IP utility tests
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::utils::ip::IpUtils;
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_ip_address() {
|
||||
// 测试有效 IP
|
||||
let valid_ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(IpUtils::is_valid_ip_address(&valid_ip));
|
||||
|
||||
// 测试未指定地址
|
||||
let unspecified_ip: IpAddr = "0.0.0.0".parse().unwrap();
|
||||
assert!(!IpUtils::is_valid_ip_address(&unspecified_ip));
|
||||
|
||||
// 测试多播地址
|
||||
let multicast_ip: IpAddr = "224.0.0.1".parse().unwrap();
|
||||
assert!(!IpUtils::is_valid_ip_address(&multicast_ip));
|
||||
|
||||
// 测试 IPv6
|
||||
let valid_ipv6: IpAddr = "2001:db8::1".parse().unwrap();
|
||||
assert!(IpUtils::is_valid_ip_address(&valid_ipv6));
|
||||
|
||||
let unspecified_ipv6: IpAddr = "::".parse().unwrap();
|
||||
assert!(!IpUtils::is_valid_ip_address(&unspecified_ipv6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_reserved_ip() {
|
||||
// 测试私有地址
|
||||
let private_ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&private_ip));
|
||||
|
||||
// 测试回环地址
|
||||
let loopback_ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&loopback_ip));
|
||||
|
||||
// 测试链路本地地址
|
||||
let link_local_ip: IpAddr = "169.254.0.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&link_local_ip));
|
||||
|
||||
// 测试文档地址
|
||||
let documentation_ip: IpAddr = "192.0.2.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&documentation_ip));
|
||||
|
||||
// 测试公网地址
|
||||
let public_ip: IpAddr = "8.8.8.8".parse().unwrap();
|
||||
assert!(!IpUtils::is_reserved_ip(&public_ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_private_ip() {
|
||||
// 测试 10.0.0.0/8
|
||||
assert!(IpUtils::is_private_ip(&"10.0.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_private_ip(&"10.255.255.254".parse().unwrap()));
|
||||
|
||||
// 测试 172.16.0.0/12
|
||||
assert!(IpUtils::is_private_ip(&"172.16.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_private_ip(&"172.31.255.254".parse().unwrap()));
|
||||
assert!(!IpUtils::is_private_ip(&"172.15.0.1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_private_ip(&"172.32.0.1".parse().unwrap()));
|
||||
|
||||
// 测试 192.168.0.0/16
|
||||
assert!(IpUtils::is_private_ip(&"192.168.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_private_ip(&"192.168.255.254".parse().unwrap()));
|
||||
|
||||
// 测试公网地址
|
||||
assert!(!IpUtils::is_private_ip(&"8.8.8.8".parse().unwrap()));
|
||||
assert!(!IpUtils::is_private_ip(&"203.0.113.1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_loopback_ip() {
|
||||
// IPv4 回环地址
|
||||
assert!(IpUtils::is_loopback_ip(&"127.0.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_loopback_ip(&"127.255.255.254".parse().unwrap()));
|
||||
|
||||
// IPv6 回环地址
|
||||
assert!(IpUtils::is_loopback_ip(&"::1".parse().unwrap()));
|
||||
|
||||
// 非回环地址
|
||||
assert!(!IpUtils::is_loopback_ip(&"192.168.1.1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_loopback_ip(&"2001:db8::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_link_local_ip() {
|
||||
// IPv4 链路本地地址
|
||||
assert!(IpUtils::is_link_local_ip(&"169.254.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_link_local_ip(&"169.254.255.254".parse().unwrap()));
|
||||
|
||||
// IPv6 链路本地地址
|
||||
assert!(IpUtils::is_link_local_ip(&"fe80::1".parse().unwrap()));
|
||||
assert!(IpUtils::is_link_local_ip(&"fe80::abcd:1234:5678:9abc".parse().unwrap()));
|
||||
|
||||
// 非链路本地地址
|
||||
assert!(!IpUtils::is_link_local_ip(&"192.168.1.1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_link_local_ip(&"2001:db8::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_documentation_ip() {
|
||||
// IPv4 文档地址
|
||||
assert!(IpUtils::is_documentation_ip(&"192.0.2.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_documentation_ip(&"198.51.100.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_documentation_ip(&"203.0.113.1".parse().unwrap()));
|
||||
|
||||
// IPv6 文档地址
|
||||
assert!(IpUtils::is_documentation_ip(&"2001:db8::1".parse().unwrap()));
|
||||
|
||||
// 非文档地址
|
||||
assert!(!IpUtils::is_documentation_ip(&"8.8.8.8".parse().unwrap()));
|
||||
assert!(!IpUtils::is_documentation_ip(&"2001:4860::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_ip_or_cidr() {
|
||||
// 测试单个 IP
|
||||
let result = IpUtils::parse_ip_or_cidr("192.168.1.1");
|
||||
assert!(result.is_ok());
|
||||
|
||||
// 测试 CIDR
|
||||
let result = IpUtils::parse_ip_or_cidr("192.168.1.0/24");
|
||||
assert!(result.is_ok());
|
||||
|
||||
// 测试 IPv6
|
||||
let result = IpUtils::parse_ip_or_cidr("2001:db8::1");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = IpUtils::parse_ip_or_cidr("2001:db8::/32");
|
||||
assert!(result.is_ok());
|
||||
|
||||
// 测试无效输入
|
||||
let result = IpUtils::parse_ip_or_cidr("invalid");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_ip_list() {
|
||||
// 测试有效的 IP 列表
|
||||
let result = IpUtils::parse_ip_list("192.168.1.1, 10.0.0.1, 8.8.8.8");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let ips = result.unwrap();
|
||||
assert_eq!(ips.len(), 3);
|
||||
assert_eq!(ips[0], IpAddr::from_str("192.168.1.1").unwrap());
|
||||
assert_eq!(ips[1], IpAddr::from_str("10.0.0.1").unwrap());
|
||||
assert_eq!(ips[2], IpAddr::from_str("8.8.8.8").unwrap());
|
||||
|
||||
// 测试带空格的 IP 列表
|
||||
let result = IpUtils::parse_ip_list("192.168.1.1,10.0.0.1");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 2);
|
||||
|
||||
// 测试空列表
|
||||
let result = IpUtils::parse_ip_list("");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
|
||||
// 测试无效 IP
|
||||
let result = IpUtils::parse_ip_list("192.168.1.1, invalid");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_network_list() {
|
||||
// 测试有效的网络列表
|
||||
let result = IpUtils::parse_network_list("192.168.1.0/24, 10.0.0.0/8");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let networks = result.unwrap();
|
||||
assert_eq!(networks.len(), 2);
|
||||
|
||||
// 测试单个 IP(会被当作/32 或/128 网络)
|
||||
let result = IpUtils::parse_network_list("192.168.1.1");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
|
||||
// 测试无效网络
|
||||
let result = IpUtils::parse_network_list("192.168.1.0/24, invalid");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_in_networks() {
|
||||
let networks = vec!["10.0.0.0/8".parse().unwrap(), "192.168.1.0/24".parse().unwrap()];
|
||||
|
||||
let ip_in_network: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
let ip_in_network2: IpAddr = "192.168.1.100".parse().unwrap();
|
||||
let ip_not_in_network: IpAddr = "8.8.8.8".parse().unwrap();
|
||||
|
||||
assert!(IpUtils::ip_in_networks(&ip_in_network, &networks));
|
||||
assert!(IpUtils::ip_in_networks(&ip_in_network2, &networks));
|
||||
assert!(!IpUtils::ip_in_networks(&ip_not_in_network, &networks));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_ip_type() {
|
||||
assert_eq!(IpUtils::get_ip_type(&"10.0.0.1".parse().unwrap()), "private");
|
||||
assert_eq!(IpUtils::get_ip_type(&"127.0.0.1".parse().unwrap()), "loopback");
|
||||
assert_eq!(IpUtils::get_ip_type(&"169.254.0.1".parse().unwrap()), "link_local");
|
||||
assert_eq!(IpUtils::get_ip_type(&"192.0.2.1".parse().unwrap()), "documentation");
|
||||
assert_eq!(IpUtils::get_ip_type(&"224.0.0.1".parse().unwrap()), "reserved");
|
||||
assert_eq!(IpUtils::get_ip_type(&"8.8.8.8".parse().unwrap()), "public");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_canonical_ip() {
|
||||
// 测试 IPv4
|
||||
let ipv4: IpAddr = "192.168.001.001".parse().unwrap();
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv4), "192.168.1.1");
|
||||
|
||||
// 测试 IPv6 压缩
|
||||
let ipv6_full: IpAddr = "2001:0db8:0000:0000:0000:0000:0000:0001".parse().unwrap();
|
||||
let ipv6_compressed: IpAddr = "2001:db8::1".parse().unwrap();
|
||||
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv6_full), "2001:db8::1");
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv6_compressed), "2001:db8::1");
|
||||
|
||||
// 测试包含多个零段的 IPv6
|
||||
let ipv6_multi_zero: IpAddr = "2001:0db8:0000:0000:abcd:0000:0000:1234".parse().unwrap();
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv6_multi_zero), "2001:db8::abcd:0:0:1234");
|
||||
}
|
||||
}
|
||||
@@ -1,637 +0,0 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Validation utility unit tests
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use http::HeaderMap;
|
||||
use std::net::IpAddr;
|
||||
|
||||
use crate::utils::ip::IpUtils;
|
||||
use crate::utils::validation::ValidationUtils;
|
||||
|
||||
/// 测试电子邮件验证
|
||||
#[test]
|
||||
fn test_email_validation() {
|
||||
// 有效的电子邮件地址
|
||||
assert!(ValidationUtils::is_valid_email("user@example.com"));
|
||||
assert!(ValidationUtils::is_valid_email("first.last@example.co.uk"));
|
||||
assert!(ValidationUtils::is_valid_email("user123@example.org"));
|
||||
assert!(ValidationUtils::is_valid_email("user+tag@example.com"));
|
||||
assert!(ValidationUtils::is_valid_email("user_name@example-domain.com"));
|
||||
|
||||
// 无效的电子邮件地址
|
||||
assert!(!ValidationUtils::is_valid_email(""));
|
||||
assert!(!ValidationUtils::is_valid_email("invalid-email"));
|
||||
assert!(!ValidationUtils::is_valid_email("user@"));
|
||||
assert!(!ValidationUtils::is_valid_email("@example.com"));
|
||||
assert!(!ValidationUtils::is_valid_email("user@.com"));
|
||||
assert!(!ValidationUtils::is_valid_email("user@example."));
|
||||
assert!(!ValidationUtils::is_valid_email("user@example..com"));
|
||||
assert!(!ValidationUtils::is_valid_email("user@example_com"));
|
||||
assert!(!ValidationUtils::is_valid_email("user@[127.0.0.1]"));
|
||||
assert!(!ValidationUtils::is_valid_email("user name@example.com"));
|
||||
assert!(!ValidationUtils::is_valid_email("user@exa mple.com"));
|
||||
assert!(!ValidationUtils::is_valid_email("user@example.c"));
|
||||
}
|
||||
|
||||
/// 测试 URL 验证
|
||||
#[test]
|
||||
fn test_url_validation() {
|
||||
// 有效的 URL
|
||||
assert!(ValidationUtils::is_valid_url("https://example.com"));
|
||||
assert!(ValidationUtils::is_valid_url("http://example.com"));
|
||||
assert!(ValidationUtils::is_valid_url("example.com"));
|
||||
assert!(ValidationUtils::is_valid_url("sub.example.com"));
|
||||
assert!(ValidationUtils::is_valid_url("example.co.uk"));
|
||||
assert!(ValidationUtils::is_valid_url("example.com/path"));
|
||||
assert!(ValidationUtils::is_valid_url("example.com/path/to/resource"));
|
||||
assert!(ValidationUtils::is_valid_url("example.com/?query=param"));
|
||||
assert!(ValidationUtils::is_valid_url("sub-domain.example-domain.com"));
|
||||
|
||||
// 无效的 URL
|
||||
assert!(!ValidationUtils::is_valid_url(""));
|
||||
assert!(!ValidationUtils::is_valid_url("invalid"));
|
||||
assert!(!ValidationUtils::is_valid_url("example"));
|
||||
assert!(!ValidationUtils::is_valid_url("example."));
|
||||
assert!(!ValidationUtils::is_valid_url(".com"));
|
||||
assert!(!ValidationUtils::is_valid_url("http://"));
|
||||
assert!(!ValidationUtils::is_valid_url("https://"));
|
||||
assert!(!ValidationUtils::is_valid_url("://example.com"));
|
||||
assert!(!ValidationUtils::is_valid_url("example..com"));
|
||||
assert!(!ValidationUtils::is_valid_url("-example.com"));
|
||||
assert!(!ValidationUtils::is_valid_url("example-.com"));
|
||||
assert!(!ValidationUtils::is_valid_url("example_com"));
|
||||
}
|
||||
|
||||
/// 测试 X-Forwarded-For 头部验证
|
||||
#[test]
|
||||
fn test_x_forwarded_for_validation() {
|
||||
// 有效的 X-Forwarded-For 头部
|
||||
assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195"));
|
||||
assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195, 198.51.100.1"));
|
||||
assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195,198.51.100.1,10.0.1.100"));
|
||||
assert!(ValidationUtils::validate_x_forwarded_for("2001:db8::1"));
|
||||
assert!(ValidationUtils::validate_x_forwarded_for("2001:db8::1, 2001:db8::2"));
|
||||
assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195:8080, 198.51.100.1:443")); // 带端口
|
||||
|
||||
// 无效的 X-Forwarded-For 头部
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for("")); // 空字符串
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for(" ")); // 只有空格
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for("invalid")); // 无效 IP
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195, invalid")); // 部分无效
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195, ")); // 尾部逗号加空格
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for(",203.0.113.195")); // 开头逗号
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for("203.0.113.195,,198.51.100.1")); // 连续逗号
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for("256.256.256.256")); // 超出范围的 IP
|
||||
}
|
||||
|
||||
/// 测试 Forwarded 头部验证 (RFC 7239)
|
||||
#[test]
|
||||
fn test_forwarded_header_validation() {
|
||||
// 有效的 Forwarded 头部
|
||||
assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60"));
|
||||
assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60;proto=http"));
|
||||
assert!(ValidationUtils::validate_forwarded_header("for=\"[2001:db8:cafe::17]\";proto=https"));
|
||||
assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.43, for=198.51.100.17"));
|
||||
assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60;proto=http;by=203.0.113.43"));
|
||||
assert!(ValidationUtils::validate_forwarded_header(
|
||||
"by=203.0.113.43;for=192.0.2.60;host=example.com;proto=https"
|
||||
));
|
||||
|
||||
// 无效的 Forwarded 头部
|
||||
assert!(!ValidationUtils::validate_forwarded_header("")); // 空字符串
|
||||
assert!(!ValidationUtils::validate_forwarded_header(" ")); // 只有空格
|
||||
assert!(!ValidationUtils::validate_forwarded_header("invalid")); // 无效格式
|
||||
assert!(!ValidationUtils::validate_forwarded_header("for=192.0.2.60 proto=http")); // 缺少分号
|
||||
assert!(!ValidationUtils::validate_forwarded_header("for;192.0.2.60")); // 缺少等号
|
||||
assert!(!ValidationUtils::validate_forwarded_header("=192.0.2.60")); // 缺少键
|
||||
assert!(!ValidationUtils::validate_forwarded_header("for=")); // 缺少值
|
||||
}
|
||||
|
||||
/// 测试 IP 范围验证
|
||||
#[test]
|
||||
fn test_ip_in_range_validation() {
|
||||
let cidr_ranges = vec![
|
||||
"10.0.0.0/8".to_string(),
|
||||
"192.168.0.0/16".to_string(),
|
||||
"172.16.0.0/12".to_string(),
|
||||
"2001:db8::/32".to_string(),
|
||||
];
|
||||
|
||||
// IP 在范围内
|
||||
let ip_in_range: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
assert!(ValidationUtils::validate_ip_in_range(&ip_in_range, &cidr_ranges));
|
||||
|
||||
let ip_in_range2: IpAddr = "192.168.1.100".parse().unwrap();
|
||||
assert!(ValidationUtils::validate_ip_in_range(&ip_in_range2, &cidr_ranges));
|
||||
|
||||
let ip_in_range3: IpAddr = "172.16.0.1".parse().unwrap();
|
||||
assert!(ValidationUtils::validate_ip_in_range(&ip_in_range3, &cidr_ranges));
|
||||
|
||||
let ipv6_in_range: IpAddr = "2001:db8::1".parse().unwrap();
|
||||
assert!(ValidationUtils::validate_ip_in_range(&ipv6_in_range, &cidr_ranges));
|
||||
|
||||
// IP 不在范围内
|
||||
let ip_not_in_range: IpAddr = "8.8.8.8".parse().unwrap();
|
||||
assert!(!ValidationUtils::validate_ip_in_range(&ip_not_in_range, &cidr_ranges));
|
||||
|
||||
let ip_not_in_range2: IpAddr = "203.0.113.1".parse().unwrap();
|
||||
assert!(!ValidationUtils::validate_ip_in_range(&ip_not_in_range2, &cidr_ranges));
|
||||
|
||||
let ipv6_not_in_range: IpAddr = "2001:4860::1".parse().unwrap();
|
||||
assert!(!ValidationUtils::validate_ip_in_range(&ipv6_not_in_range, &cidr_ranges));
|
||||
|
||||
// 空范围列表
|
||||
assert!(!ValidationUtils::validate_ip_in_range(&ip_in_range, &Vec::new()));
|
||||
|
||||
// 无效的 CIDR 范围(应该被忽略)
|
||||
let invalid_ranges = vec![
|
||||
"invalid".to_string(),
|
||||
"10.0.0.0/8".to_string(), // 这个有效
|
||||
];
|
||||
|
||||
let test_ip: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
// 即使有无效范围,只要有一个有效范围包含 IP,就应该返回 true
|
||||
assert!(ValidationUtils::validate_ip_in_range(&test_ip, &invalid_ranges));
|
||||
}
|
||||
|
||||
/// 测试头部值验证
|
||||
#[test]
|
||||
fn test_header_value_validation() {
|
||||
// 有效的头部值
|
||||
assert!(ValidationUtils::validate_header_value("text/plain"));
|
||||
assert!(ValidationUtils::validate_header_value("application/json; charset=utf-8"));
|
||||
assert!(ValidationUtils::validate_header_value("Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"));
|
||||
assert!(ValidationUtils::validate_header_value("127.0.0.1:8080"));
|
||||
assert!(ValidationUtils::validate_header_value("")); // 空字符串是有效的
|
||||
assert!(ValidationUtils::validate_header_value("normal text with spaces")); // 普通文本
|
||||
assert!(ValidationUtils::validate_header_value("value\twith\ttabs")); // 包含制表符
|
||||
|
||||
// 长但有效的头部值(边界情况)
|
||||
let long_value = "a".repeat(8192);
|
||||
assert!(ValidationUtils::validate_header_value(&long_value));
|
||||
|
||||
// 无效的头部值
|
||||
let too_long_value = "a".repeat(8193); // 超过长度限制
|
||||
assert!(!ValidationUtils::validate_header_value(&too_long_value));
|
||||
|
||||
// 包含控制字符(除了制表符、换行符、回车符)
|
||||
assert!(!ValidationUtils::validate_header_value("value\x00with_null")); // 空字符
|
||||
assert!(!ValidationUtils::validate_header_value("value\x01with_soh")); // 标题开始
|
||||
assert!(!ValidationUtils::validate_header_value("value\x1fwith_us")); // 单元分隔符
|
||||
|
||||
// 换行符和回车符是允许的(在某些上下文中)
|
||||
assert!(ValidationUtils::validate_header_value("line1\nline2"));
|
||||
assert!(ValidationUtils::validate_header_value("line1\r\nline2"));
|
||||
}
|
||||
|
||||
/// 测试头部映射验证
|
||||
#[test]
|
||||
fn test_headers_validation() {
|
||||
let mut valid_headers = HeaderMap::new();
|
||||
valid_headers.insert("Content-Type", "application/json".parse().unwrap());
|
||||
valid_headers.insert("Authorization", "Bearer token123".parse().unwrap());
|
||||
valid_headers.insert("X-Forwarded-For", "203.0.113.195".parse().unwrap());
|
||||
|
||||
assert!(ValidationUtils::validate_headers(&valid_headers));
|
||||
|
||||
// 测试头部名称过长
|
||||
let mut invalid_headers = HeaderMap::new();
|
||||
let long_name = "X-".to_string() + &"A".repeat(300); // 超过 256 字符
|
||||
invalid_headers.insert(long_name, "value".parse().unwrap());
|
||||
|
||||
assert!(!ValidationUtils::validate_headers(&invalid_headers));
|
||||
|
||||
// 测试头部值过长
|
||||
let mut invalid_headers2 = HeaderMap::new();
|
||||
let long_value = "A".repeat(8193); // 超过 8192 字节
|
||||
invalid_headers2.insert("X-Custom-Header", long_value.parse().unwrap());
|
||||
|
||||
assert!(!ValidationUtils::validate_headers(&invalid_headers2));
|
||||
|
||||
// 测试二进制数据(无法转换为字符串)
|
||||
let mut binary_headers = HeaderMap::new();
|
||||
let binary_data = vec![0x00, 0x01, 0x02, 0x03];
|
||||
binary_headers.insert("X-Binary-Data", http::HeaderValue::from_bytes(&binary_data).unwrap());
|
||||
|
||||
// 二进制数据应该通过验证(只要长度不超过限制)
|
||||
assert!(ValidationUtils::validate_headers(&binary_headers));
|
||||
|
||||
// 测试过长的二进制数据
|
||||
let mut long_binary_headers = HeaderMap::new();
|
||||
let long_binary_data = vec![0x00; 8193]; // 超过 8192 字节
|
||||
long_binary_headers.insert("X-Long-Binary", http::HeaderValue::from_bytes(&long_binary_data).unwrap());
|
||||
|
||||
assert!(!ValidationUtils::validate_headers(&long_binary_headers));
|
||||
}
|
||||
|
||||
/// 测试端口号验证
|
||||
#[test]
|
||||
fn test_port_validation() {
|
||||
// 有效端口号
|
||||
assert!(ValidationUtils::validate_port(1));
|
||||
assert!(ValidationUtils::validate_port(80));
|
||||
assert!(ValidationUtils::validate_port(443));
|
||||
assert!(ValidationUtils::validate_port(8080));
|
||||
assert!(ValidationUtils::validate_port(65535));
|
||||
|
||||
// 无效端口号
|
||||
assert!(!ValidationUtils::validate_port(0)); // 端口 0 是保留的
|
||||
assert!(!ValidationUtils::validate_port(65536)); // 超过最大值
|
||||
assert!(!ValidationUtils::validate_port(70000)); // 远超过最大值
|
||||
}
|
||||
|
||||
/// 测试 CIDR 表示法验证
|
||||
#[test]
|
||||
fn test_cidr_validation() {
|
||||
// 有效的 CIDR 表示法
|
||||
assert!(ValidationUtils::validate_cidr("192.168.1.0/24"));
|
||||
assert!(ValidationUtils::validate_cidr("10.0.0.0/8"));
|
||||
assert!(ValidationUtils::validate_cidr("0.0.0.0/0")); // 默认路由
|
||||
assert!(ValidationUtils::validate_cidr("2001:db8::/32"));
|
||||
assert!(ValidationUtils::validate_cidr("::/0")); // IPv6 默认路由
|
||||
assert!(ValidationUtils::validate_cidr("192.168.1.1/32")); // 单个主机
|
||||
assert!(ValidationUtils::validate_cidr("2001:db8::1/128")); // 单个 IPv6 主机
|
||||
|
||||
// 无效的 CIDR 表示法
|
||||
assert!(!ValidationUtils::validate_cidr("")); // 空字符串
|
||||
assert!(!ValidationUtils::validate_cidr("invalid")); // 无效格式
|
||||
assert!(!ValidationUtils::validate_cidr("192.168.1.0")); // 缺少前缀长度
|
||||
assert!(!ValidationUtils::validate_cidr("192.168.1.0/33")); // 前缀长度过大
|
||||
assert!(!ValidationUtils::validate_cidr("256.256.256.256/24")); // 无效 IP
|
||||
assert!(!ValidationUtils::validate_cidr("192.168.1.0/24/extra")); // 多余的部分
|
||||
assert!(!ValidationUtils::validate_cidr("192.168.1.0/-1")); // 负的前缀长度
|
||||
assert!(!ValidationUtils::validate_cidr("192.168.1.0/abc")); // 非数字前缀长度
|
||||
}
|
||||
|
||||
/// 测试代理链长度验证
|
||||
#[test]
|
||||
fn test_proxy_chain_length_validation() {
|
||||
let chain = vec![
|
||||
"203.0.113.195".parse().unwrap(),
|
||||
"198.51.100.1".parse().unwrap(),
|
||||
"10.0.1.100".parse().unwrap(),
|
||||
];
|
||||
|
||||
// 链长度在限制内
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&chain, 3));
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&chain, 5));
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&chain, 10));
|
||||
|
||||
// 链长度超过限制
|
||||
assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 2));
|
||||
assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 1));
|
||||
assert!(!ValidationUtils::validate_proxy_chain_length(&chain, 0));
|
||||
|
||||
// 空链
|
||||
let empty_chain: Vec<IpAddr> = Vec::new();
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 0));
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 1));
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 10));
|
||||
}
|
||||
|
||||
/// 测试代理链连续性验证
|
||||
#[test]
|
||||
fn test_proxy_chain_continuity_validation() {
|
||||
// 连续链(无重复相邻 IP)
|
||||
let continuous_chain = vec![
|
||||
"203.0.113.195".parse().unwrap(),
|
||||
"198.51.100.1".parse().unwrap(),
|
||||
"10.0.1.100".parse().unwrap(),
|
||||
];
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&continuous_chain));
|
||||
|
||||
// 不连续链(有重复相邻 IP)
|
||||
let discontinuous_chain = vec![
|
||||
"203.0.113.195".parse().unwrap(),
|
||||
"198.51.100.1".parse().unwrap(),
|
||||
"198.51.100.1".parse().unwrap(), // 重复
|
||||
"10.0.1.100".parse().unwrap(),
|
||||
];
|
||||
assert!(!ValidationUtils::validate_proxy_chain_continuity(&discontinuous_chain));
|
||||
|
||||
// 短链(应该总是连续的)
|
||||
let short_chain = vec!["203.0.113.195".parse().unwrap()];
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&short_chain));
|
||||
|
||||
let two_item_chain = vec!["203.0.113.195".parse().unwrap(), "198.51.100.1".parse().unwrap()];
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&two_item_chain));
|
||||
|
||||
// 空链(应该总是连续的)
|
||||
let empty_chain: Vec<IpAddr> = Vec::new();
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&empty_chain));
|
||||
|
||||
// 有多个重复的链
|
||||
let multi_duplicate_chain = vec![
|
||||
"203.0.113.195".parse().unwrap(),
|
||||
"203.0.113.195".parse().unwrap(), // 重复 1
|
||||
"198.51.100.1".parse().unwrap(),
|
||||
"198.51.100.1".parse().unwrap(), // 重复 2
|
||||
];
|
||||
assert!(!ValidationUtils::validate_proxy_chain_continuity(&multi_duplicate_chain));
|
||||
}
|
||||
|
||||
/// 测试安全字符串验证
|
||||
#[test]
|
||||
fn test_safe_string_validation() {
|
||||
// 安全字符串
|
||||
assert!(ValidationUtils::is_safe_string("example"));
|
||||
assert!(ValidationUtils::is_safe_string("example123"));
|
||||
assert!(ValidationUtils::is_safe_string("example-test"));
|
||||
assert!(ValidationUtils::is_safe_string("example.test"));
|
||||
assert!(ValidationUtils::is_safe_string("example~test"));
|
||||
assert!(ValidationUtils::is_safe_string("http://example.com/path"));
|
||||
assert!(ValidationUtils::is_safe_string("https://example.com/?query=param"));
|
||||
assert!(ValidationUtils::is_safe_string("user@example.com"));
|
||||
assert!(ValidationUtils::is_safe_string("192.168.1.1:8080"));
|
||||
assert!(ValidationUtils::is_safe_string("[2001:db8::1]:8080"));
|
||||
|
||||
// 不安全字符串
|
||||
assert!(!ValidationUtils::is_safe_string("")); // 空字符串
|
||||
assert!(!ValidationUtils::is_safe_string("example test")); // 包含空格
|
||||
assert!(!ValidationUtils::is_safe_string("example\ttest")); // 包含制表符
|
||||
assert!(!ValidationUtils::is_safe_string("example\ntest")); // 包含换行符
|
||||
assert!(!ValidationUtils::is_safe_string("example<script>alert('xss')</script>")); // 包含尖括号
|
||||
assert!(!ValidationUtils::is_safe_string("example\"test")); // 包含双引号
|
||||
assert!(!ValidationUtils::is_safe_string("example'test")); // 包含单引号
|
||||
assert!(!ValidationUtils::is_safe_string("example\\test")); // 包含反斜杠
|
||||
assert!(!ValidationUtils::is_safe_string("example`test")); // 包含反引号
|
||||
assert!(!ValidationUtils::is_safe_string("example|test")); // 包含竖线
|
||||
assert!(!ValidationUtils::is_safe_string("example$test")); // 包含美元符号
|
||||
assert!(!ValidationUtils::is_safe_string("example%test")); // 包含百分号
|
||||
assert!(!ValidationUtils::is_safe_string("example^test")); // 包含脱字符
|
||||
assert!(!ValidationUtils::is_safe_string("example&test")); // 包含和号
|
||||
assert!(!ValidationUtils::is_safe_string("example(test")); // 包含括号
|
||||
assert!(!ValidationUtils::is_safe_string("example)test")); // 包含括号
|
||||
assert!(!ValidationUtils::is_safe_string("example[test")); // 包含方括号
|
||||
assert!(!ValidationUtils::is_safe_string("example]test")); // 包含方括号
|
||||
assert!(!ValidationUtils::is_safe_string("example{test")); // 包含花括号
|
||||
assert!(!ValidationUtils::is_safe_string("example}test")); // 包含花括号
|
||||
}
|
||||
|
||||
/// 测试速率限制参数验证
|
||||
#[test]
|
||||
fn test_rate_limit_params_validation() {
|
||||
// 有效的速率限制参数
|
||||
assert!(ValidationUtils::validate_rate_limit_params(1, 1)); // 最小值
|
||||
assert!(ValidationUtils::validate_rate_limit_params(100, 60)); // 典型值
|
||||
assert!(ValidationUtils::validate_rate_limit_params(10000, 86400)); // 最大值
|
||||
|
||||
// 无效的速率限制参数
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(0, 60)); // 请求数为 0
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(10001, 60)); // 请求数超过最大值
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(100, 0)); // 周期为 0
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(100, 86401)); // 周期超过最大值
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(0, 0)); // 两者都为 0
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(100001, 100000)); // 两者都超过最大值
|
||||
}
|
||||
|
||||
/// 测试缓存参数验证
|
||||
#[test]
|
||||
fn test_cache_params_validation() {
|
||||
// 有效的缓存参数
|
||||
assert!(ValidationUtils::validate_cache_params(1, 1)); // 最小值
|
||||
assert!(ValidationUtils::validate_cache_params(10000, 300)); // 典型值
|
||||
assert!(ValidationUtils::validate_cache_params(1000000, 86400)); // 最大值
|
||||
|
||||
// 无效的缓存参数
|
||||
assert!(!ValidationUtils::validate_cache_params(0, 300)); // 容量为 0
|
||||
assert!(!ValidationUtils::validate_cache_params(1000001, 300)); // 容量超过最大值
|
||||
assert!(!ValidationUtils::validate_cache_params(10000, 0)); // TTL 为 0
|
||||
assert!(!ValidationUtils::validate_cache_params(10000, 86401)); // TTL 超过最大值
|
||||
assert!(!ValidationUtils::validate_cache_params(0, 0)); // 两者都为 0
|
||||
assert!(!ValidationUtils::validate_cache_params(2000000, 100000)); // 两者都超过最大值
|
||||
}
|
||||
|
||||
/// 测试敏感数据脱敏
|
||||
#[test]
|
||||
fn test_sensitive_data_masking() {
|
||||
let sensitive_patterns = vec!["password", "token", "secret", "authorization", "api_key"];
|
||||
|
||||
// 测试各种敏感字段的脱敏
|
||||
let test_cases = vec![
|
||||
(
|
||||
r#"{"username":"john","password":"secret123"}"#,
|
||||
r#"{"username":"john","password:[REDACTED]"}"#,
|
||||
),
|
||||
(r#"token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9&user=john"#, r#"token:[REDACTED]&user=john"#),
|
||||
(
|
||||
r#"Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"#,
|
||||
r#"Authorization:[REDACTED]"#,
|
||||
),
|
||||
(r#"api_key=sk_test_1234567890abcdef"#, r#"api_key:[REDACTED]"#),
|
||||
(r#"secret_key=abc123&public_key=xyz789"#, r#"secret_key:[REDACTED]&public_key=xyz789"#),
|
||||
(
|
||||
r#"password=123&password_confirmation=123"#,
|
||||
r#"password:[REDACTED]&password_confirmation:[REDACTED]"#,
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected) in test_cases {
|
||||
let result = ValidationUtils::mask_sensitive_data(input, &sensitive_patterns);
|
||||
assert_eq!(result, expected, "Failed to mask: {}", input);
|
||||
}
|
||||
|
||||
// 测试不包含敏感数据的情况
|
||||
let safe_data = r#"{"name":"John","age":30,"city":"New York"}"#;
|
||||
let result = ValidationUtils::mask_sensitive_data(safe_data, &sensitive_patterns);
|
||||
assert_eq!(result, safe_data);
|
||||
|
||||
// 测试空模式列表
|
||||
let sensitive_data = r#"password=secret123"#;
|
||||
let result = ValidationUtils::mask_sensitive_data(sensitive_data, &Vec::new());
|
||||
assert_eq!(result, sensitive_data);
|
||||
|
||||
// 测试空输入
|
||||
let result = ValidationUtils::mask_sensitive_data("", &sensitive_patterns);
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
||||
/// 测试组合验证场景
|
||||
#[test]
|
||||
fn test_combined_validation_scenarios() {
|
||||
// 场景 1:完整的代理请求验证
|
||||
let proxy_chain = vec![
|
||||
"203.0.113.195".parse().unwrap(),
|
||||
"198.51.100.1".parse().unwrap(),
|
||||
"10.0.1.100".parse().unwrap(),
|
||||
];
|
||||
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&proxy_chain, 10));
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&proxy_chain));
|
||||
|
||||
// 场景 2:包含无效数据的头部验证
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/json".parse().unwrap());
|
||||
headers.insert("X-Forwarded-For", "203.0.113.195, invalid, 10.0.1.100".parse().unwrap());
|
||||
|
||||
// 头部映射本身是有效的(即使包含无效的 X-Forwarded-For)
|
||||
assert!(ValidationUtils::validate_headers(&headers));
|
||||
|
||||
// 但 X-Forwarded-For 内容无效
|
||||
let xff_value = headers.get("X-Forwarded-For").unwrap().to_str().unwrap();
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for(xff_value));
|
||||
|
||||
// 场景 3:配置参数验证组合
|
||||
let cache_capacity = 10000;
|
||||
let cache_ttl = 300;
|
||||
let rate_limit_requests = 100;
|
||||
let rate_limit_period = 60;
|
||||
|
||||
assert!(ValidationUtils::validate_cache_params(cache_capacity, cache_ttl));
|
||||
assert!(ValidationUtils::validate_rate_limit_params(rate_limit_requests, rate_limit_period));
|
||||
|
||||
// 场景 4:IP 和 CIDR 验证组合
|
||||
let ip: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
let cidr = "10.0.0.0/8";
|
||||
|
||||
assert!(IpUtils::is_private_ip(&ip));
|
||||
assert!(ValidationUtils::validate_cidr(cidr));
|
||||
assert!(ValidationUtils::validate_ip_in_range(&ip, &[cidr.to_string()]));
|
||||
}
|
||||
|
||||
/// 测试边缘情况和边界值
|
||||
#[test]
|
||||
fn test_edge_cases_and_boundaries() {
|
||||
// 测试头部值的边界长度
|
||||
let max_length_value = "a".repeat(8192);
|
||||
let over_length_value = "a".repeat(8193);
|
||||
|
||||
assert!(ValidationUtils::validate_header_value(&max_length_value));
|
||||
assert!(!ValidationUtils::validate_header_value(&over_length_value));
|
||||
|
||||
// 测试端口边界值
|
||||
assert!(!ValidationUtils::validate_port(0)); // 最小无效值
|
||||
assert!(ValidationUtils::validate_port(1)); // 最小有效值
|
||||
assert!(ValidationUtils::validate_port(65535)); // 最大有效值
|
||||
assert!(!ValidationUtils::validate_port(65536)); // 超过最大值
|
||||
|
||||
// 测试 CIDR 前缀长度边界值
|
||||
assert!(ValidationUtils::validate_cidr("192.168.1.0/0")); // 最小有效前缀
|
||||
assert!(ValidationUtils::validate_cidr("192.168.1.0/32")); // 最大有效前缀
|
||||
assert!(!ValidationUtils::validate_cidr("192.168.1.0/33")); // 超过最大值
|
||||
|
||||
// IPv6 CIDR 前缀长度
|
||||
assert!(ValidationUtils::validate_cidr("2001:db8::/0")); // 最小有效前缀
|
||||
assert!(ValidationUtils::validate_cidr("2001:db8::/128")); // 最大有效前缀
|
||||
|
||||
// 测试代理链边界情况
|
||||
let empty_chain: Vec<IpAddr> = Vec::new();
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&empty_chain, 0));
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&empty_chain));
|
||||
|
||||
// 测试单个 IP 的链
|
||||
let single_ip_chain = vec!["192.168.1.1".parse().unwrap()];
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&single_ip_chain, 1));
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&single_ip_chain));
|
||||
|
||||
// 测试速率限制边界值
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(0, 60));
|
||||
assert!(ValidationUtils::validate_rate_limit_params(1, 1));
|
||||
assert!(ValidationUtils::validate_rate_limit_params(10000, 86400));
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(10001, 86400));
|
||||
assert!(!ValidationUtils::validate_rate_limit_params(10000, 86401));
|
||||
|
||||
// 测试缓存参数边界值
|
||||
assert!(!ValidationUtils::validate_cache_params(0, 300));
|
||||
assert!(ValidationUtils::validate_cache_params(1, 1));
|
||||
assert!(ValidationUtils::validate_cache_params(1000000, 86400));
|
||||
assert!(!ValidationUtils::validate_cache_params(1000001, 86400));
|
||||
assert!(!ValidationUtils::validate_cache_params(1000000, 86401));
|
||||
}
|
||||
|
||||
/// 测试性能敏感场景
|
||||
#[test]
|
||||
fn test_performance_sensitive_scenarios() {
|
||||
// 测试长代理链的处理
|
||||
let mut long_chain = Vec::new();
|
||||
for i in 0..100 {
|
||||
let ip = format!("10.0.{}.1", i % 256).parse().unwrap();
|
||||
long_chain.push(ip);
|
||||
}
|
||||
|
||||
// 应该能快速处理长链
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&long_chain, 100));
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&long_chain));
|
||||
|
||||
// 测试大量 CIDR 范围的验证
|
||||
let mut cidr_ranges = Vec::new();
|
||||
for i in 0..1000 {
|
||||
let cidr = format!("10.{}.0.0/16", i % 256);
|
||||
cidr_ranges.push(cidr);
|
||||
}
|
||||
|
||||
let test_ip: IpAddr = "10.128.1.1".parse().unwrap();
|
||||
// 应该能快速在大范围列表中查找
|
||||
let start = std::time::Instant::now();
|
||||
let result = ValidationUtils::validate_ip_in_range(&test_ip, &cidr_ranges);
|
||||
let duration = start.elapsed();
|
||||
|
||||
assert!(result);
|
||||
// 验证时间应该在合理范围内(比如小于 10 毫秒)
|
||||
assert!(duration < std::time::Duration::from_millis(10));
|
||||
|
||||
// 测试头部值验证的性能
|
||||
let large_header_value = "x".repeat(10000); // 超过 8192,应该快速拒绝
|
||||
let start = std::time::Instant::now();
|
||||
let result = ValidationUtils::validate_header_value(&large_header_value);
|
||||
let duration = start.elapsed();
|
||||
|
||||
assert!(!result); // 应该拒绝
|
||||
assert!(duration < std::time::Duration::from_millis(1)); // 应该非常快
|
||||
}
|
||||
|
||||
/// 测试实际代理场景模拟
|
||||
#[test]
|
||||
fn test_real_world_proxy_scenarios() {
|
||||
// 场景 1:典型的反向代理配置
|
||||
let typical_xff = "203.0.113.195, 198.51.100.1, 10.0.1.100";
|
||||
assert!(ValidationUtils::validate_x_forwarded_for(typical_xff));
|
||||
|
||||
let typical_proxy_chain: Vec<IpAddr> = typical_xff.split(',').map(|s| s.trim().parse().unwrap()).collect();
|
||||
|
||||
assert_eq!(typical_proxy_chain.len(), 3);
|
||||
assert!(ValidationUtils::validate_proxy_chain_length(&typical_proxy_chain, 10));
|
||||
assert!(ValidationUtils::validate_proxy_chain_continuity(&typical_proxy_chain));
|
||||
|
||||
// 场景 2:负载均衡器场景
|
||||
let lb_scenario = "2001:db8::1, 203.0.113.195, 198.51.100.1";
|
||||
assert!(ValidationUtils::validate_x_forwarded_for(lb_scenario));
|
||||
|
||||
// 场景 3:可能被攻击的头部
|
||||
let attack_headers = vec![
|
||||
("X-Forwarded-For", "127.0.0.1, 8.8.8.8, 192.168.1.1"),
|
||||
("X-Real-IP", "8.8.8.8"),
|
||||
("X-Forwarded-Host", "evil.com"),
|
||||
];
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
for (name, value) in attack_headers {
|
||||
headers.insert(name, value.parse().unwrap());
|
||||
}
|
||||
|
||||
// 头部格式本身应该是有效的
|
||||
assert!(ValidationUtils::validate_headers(&headers));
|
||||
|
||||
// 但内容可能需要进一步验证
|
||||
let xff_value = headers.get("X-Forwarded-For").unwrap().to_str().unwrap();
|
||||
assert!(ValidationUtils::validate_x_forwarded_for(xff_value));
|
||||
|
||||
// 场景 4:RFC 7239 格式
|
||||
let rfc7239_header = "for=192.0.2.60;proto=https;by=203.0.113.43";
|
||||
assert!(ValidationUtils::validate_forwarded_header(rfc7239_header));
|
||||
}
|
||||
}
|
||||
@@ -1,225 +0,0 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Proxy validator unit tests
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
|
||||
use crate::config::{TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
use crate::proxy::chain::ProxyChainAnalyzer;
|
||||
use crate::proxy::validator::{ClientInfo, ProxyValidator};
|
||||
|
||||
fn create_test_config() -> TrustedProxyConfig {
|
||||
let proxies = vec![
|
||||
TrustedProxy::Single("192.168.1.100".parse().unwrap()),
|
||||
TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()),
|
||||
TrustedProxy::Cidr("172.16.0.0/12".parse().unwrap()),
|
||||
];
|
||||
|
||||
TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 5, true, vec![])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_info_direct() {
|
||||
let addr = SocketAddr::new(IpAddr::from([192, 168, 1, 1]), 8080);
|
||||
let client_info = ClientInfo::direct(addr);
|
||||
|
||||
assert_eq!(client_info.real_ip, IpAddr::from([192, 168, 1, 1]));
|
||||
assert!(client_info.forwarded_host.is_none());
|
||||
assert!(client_info.forwarded_proto.is_none());
|
||||
assert!(!client_info.is_from_trusted_proxy);
|
||||
assert!(client_info.proxy_ip.is_none());
|
||||
assert_eq!(client_info.proxy_hops, 0);
|
||||
assert_eq!(client_info.validation_mode, ValidationMode::Lenient);
|
||||
assert!(client_info.warnings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_x_forwarded_for() {
|
||||
use crate::proxy::validator::ProxyValidator;
|
||||
|
||||
// 测试有效的 X-Forwarded-For 头部
|
||||
let header_value = "203.0.113.195, 198.51.100.1, 10.0.1.100";
|
||||
let result = ProxyValidator::parse_x_forwarded_for(header_value);
|
||||
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap());
|
||||
assert_eq!(result[1], IpAddr::from_str("198.51.100.1").unwrap());
|
||||
assert_eq!(result[2], IpAddr::from_str("10.0.1.100").unwrap());
|
||||
|
||||
// 测试带端口的 IP
|
||||
let header_value_with_ports = "203.0.113.195:8080, 198.51.100.1:443";
|
||||
let result = ProxyValidator::parse_x_forwarded_for(header_value_with_ports);
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0], IpAddr::from_str("203.0.113.195").unwrap());
|
||||
assert_eq!(result[1], IpAddr::from_str("198.51.100.1").unwrap());
|
||||
|
||||
// 测试空值
|
||||
let empty_result = ProxyValidator::parse_x_forwarded_for("");
|
||||
assert!(empty_result.is_empty());
|
||||
|
||||
// 测试无效 IP
|
||||
let invalid_result = ProxyValidator::parse_x_forwarded_for("invalid, 203.0.113.195");
|
||||
assert_eq!(invalid_result.len(), 1); // 无效项被跳过
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_chain_analyzer_lenient() {
|
||||
let config = create_test_config();
|
||||
let analyzer = ProxyChainAnalyzer::new(config.clone());
|
||||
|
||||
// 测试链:客户端 -> 可信代理 1 -> 可信代理 2
|
||||
let chain = vec![
|
||||
IpAddr::from_str("203.0.113.195").unwrap(), // 客户端
|
||||
IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1
|
||||
IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2
|
||||
];
|
||||
|
||||
let current_proxy = IpAddr::from_str("192.168.1.100").unwrap();
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
let result = analyzer.analyze_chain(&chain, current_proxy, &headers);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let analysis = result.unwrap();
|
||||
assert_eq!(analysis.client_ip, IpAddr::from_str("203.0.113.195").unwrap());
|
||||
assert_eq!(analysis.hops, 3);
|
||||
assert!(analysis.is_continuous);
|
||||
assert_eq!(analysis.validation_mode, ValidationMode::HopByHop);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_chain_analyzer_strict() {
|
||||
let mut config = create_test_config();
|
||||
config.validation_mode = ValidationMode::Strict;
|
||||
|
||||
let analyzer = ProxyChainAnalyzer::new(config);
|
||||
|
||||
// 测试链:客户端 -> 可信代理 1 -> 可信代理 2 (全部可信)
|
||||
let chain = vec![
|
||||
IpAddr::from_str("203.0.113.195").unwrap(), // 客户端
|
||||
IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1
|
||||
IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2
|
||||
];
|
||||
|
||||
let current_proxy = IpAddr::from_str("192.168.1.100").unwrap();
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
let result = analyzer.analyze_chain(&chain, current_proxy, &headers);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// 测试链包含不可信代理
|
||||
let chain_with_untrusted = vec![
|
||||
IpAddr::from_str("203.0.113.195").unwrap(), // 客户端
|
||||
IpAddr::from_str("8.8.8.8").unwrap(), // 不可信代理
|
||||
IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2
|
||||
];
|
||||
|
||||
let result = analyzer.analyze_chain(&chain_with_untrusted, current_proxy, &headers);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_chain_analyzer_hop_by_hop() {
|
||||
let config = create_test_config();
|
||||
let analyzer = ProxyChainAnalyzer::new(config);
|
||||
|
||||
// 测试链:客户端 -> 不可信代理 -> 可信代理 1 -> 可信代理 2
|
||||
let chain = vec![
|
||||
IpAddr::from_str("203.0.113.195").unwrap(), // 客户端
|
||||
IpAddr::from_str("8.8.8.8").unwrap(), // 不可信代理
|
||||
IpAddr::from_str("10.0.1.100").unwrap(), // 可信代理 1
|
||||
IpAddr::from_str("192.168.1.100").unwrap(), // 可信代理 2
|
||||
];
|
||||
|
||||
let current_proxy = IpAddr::from_str("192.168.1.100").unwrap();
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
let result = analyzer.analyze_chain(&chain, current_proxy, &headers);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let analysis = result.unwrap();
|
||||
// 应该找到客户端 IP (203.0.113.195)
|
||||
assert_eq!(analysis.client_ip, IpAddr::from_str("203.0.113.195").unwrap());
|
||||
// 应该验证 2 跳 (10.0.1.100 和 192.168.1.100)
|
||||
assert_eq!(analysis.hops, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chain_continuity_check() {
|
||||
let config = create_test_config();
|
||||
let analyzer = ProxyChainAnalyzer::new(config);
|
||||
|
||||
// 测试连续链
|
||||
let full_chain = vec![
|
||||
IpAddr::from_str("203.0.113.195").unwrap(),
|
||||
IpAddr::from_str("10.0.1.100").unwrap(),
|
||||
IpAddr::from_str("192.168.1.100").unwrap(),
|
||||
];
|
||||
|
||||
let trusted_chain = vec![
|
||||
IpAddr::from_str("10.0.1.100").unwrap(),
|
||||
IpAddr::from_str("192.168.1.100").unwrap(),
|
||||
];
|
||||
|
||||
assert!(analyzer.check_chain_continuity(&full_chain, &trusted_chain));
|
||||
|
||||
// 测试不连续链
|
||||
let bad_trusted_chain = vec![IpAddr::from_str("192.168.1.100").unwrap()];
|
||||
|
||||
assert!(!analyzer.check_chain_continuity(&full_chain, &bad_trusted_chain));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_ip_addresses() {
|
||||
let config = create_test_config();
|
||||
let analyzer = ProxyChainAnalyzer::new(config);
|
||||
|
||||
// 测试有效 IP
|
||||
let valid_chain = vec![
|
||||
IpAddr::from_str("203.0.113.195").unwrap(),
|
||||
IpAddr::from_str("10.0.1.100").unwrap(),
|
||||
];
|
||||
|
||||
let result = analyzer.validate_ip_addresses(&valid_chain);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// 测试未指定地址
|
||||
let invalid_chain = vec![IpAddr::from_str("0.0.0.0").unwrap()];
|
||||
|
||||
let result = analyzer.validate_ip_addresses(&invalid_chain);
|
||||
assert!(result.is_err());
|
||||
|
||||
// 测试多播地址
|
||||
let multicast_chain = vec![IpAddr::from_str("224.0.0.1").unwrap()];
|
||||
|
||||
let result = analyzer.validate_ip_addresses(&multicast_chain);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_validator_creation() {
|
||||
let config = create_test_config();
|
||||
let validator = ProxyValidator::new(config, None);
|
||||
|
||||
// 验证器应该成功创建
|
||||
assert!(true); // 如果没有 panic,测试通过
|
||||
}
|
||||
}
|
||||
49
crates/trusted-proxies/tests/integration/api_tests.rs
Normal file
49
crates/trusted-proxies/tests/integration/api_tests.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
use axum::body::Body;
|
||||
use axum::{routing::get, Router};
|
||||
use serde_json::{json};
|
||||
use tower::ServiceExt;
|
||||
use rustfs_trusted_proxies::config::{AppConfig, TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
use rustfs_trusted_proxies::state::AppState;
|
||||
|
||||
fn create_test_app_state() -> AppState {
|
||||
let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())];
|
||||
let proxy_config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]);
|
||||
let config = AppConfig::new(
|
||||
proxy_config,
|
||||
rustfs_trusted_proxies::config::CacheConfig::default(),
|
||||
rustfs_trusted_proxies::config::MonitoringConfig::default(),
|
||||
rustfs_trusted_proxies::config::CloudConfig::default(),
|
||||
"127.0.0.1:3000".parse().unwrap(),
|
||||
);
|
||||
AppState {
|
||||
config: Arc::new(config),
|
||||
metrics: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_check_endpoint() {
|
||||
let state = create_test_app_state();
|
||||
let app = Router::new()
|
||||
.route("/health", get(|| async { axum::response::Json(json!({"status": "healthy"})) }))
|
||||
.with_state(state);
|
||||
|
||||
let request = axum::http::Request::builder().uri("/health").body(Body::empty()).unwrap();
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
}
|
||||
30
crates/trusted-proxies/tests/integration/cloud_tests.rs
Normal file
30
crates/trusted-proxies/tests/integration/cloud_tests.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::time::Duration;
|
||||
use rustfs_trusted_proxies::cloud::detector::CloudDetector;
|
||||
use rustfs_trusted_proxies::cloud::metadata::AwsMetadataFetcher;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cloud_detector_disabled() {
|
||||
let detector = CloudDetector::new(false, Duration::from_secs(1), None);
|
||||
let provider = detector.detect_provider();
|
||||
assert!(provider.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aws_metadata_fetcher() {
|
||||
let fetcher = AwsMetadataFetcher::new();
|
||||
assert_eq!(fetcher.provider_name(), "aws");
|
||||
}
|
||||
@@ -12,13 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Integration tests for the trusted proxy system
|
||||
//! Integration tests for the trusted proxy system.
|
||||
|
||||
#[cfg(test)]
|
||||
mod api_tests;
|
||||
#[cfg(test)]
|
||||
mod cloud_tests;
|
||||
#[cfg(test)]
|
||||
mod proxy_tests;
|
||||
|
||||
// 重新导出测试模块
|
||||
pub use api_tests::*;
|
||||
pub use cloud_tests::*;
|
||||
pub use proxy_tests::*;
|
||||
39
crates/trusted-proxies/tests/integration/proxy_tests.rs
Normal file
39
crates/trusted-proxies/tests/integration/proxy_tests.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::{routing::get, Router};
|
||||
use tower::ServiceExt;
|
||||
use rustfs_trusted_proxies::config::{TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
use rustfs_trusted_proxies::middleware::TrustedProxyLayer;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_proxy_validation_flow() {
|
||||
let proxies = vec![TrustedProxy::Single("127.0.0.1".parse().unwrap())];
|
||||
let config = TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 10, true, vec![]);
|
||||
let proxy_layer = TrustedProxyLayer::enabled(config, None);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/test", get(|| async { "OK" }))
|
||||
.layer(proxy_layer);
|
||||
|
||||
let request = axum::http::Request::builder()
|
||||
.uri("/test")
|
||||
.header("X-Forwarded-For", "203.0.113.195")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), 200);
|
||||
}
|
||||
117
crates/trusted-proxies/tests/unit/config_tests.rs
Normal file
117
crates/trusted-proxies/tests/unit/config_tests.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::net::IpAddr;
|
||||
use rustfs_trusted_proxies::config::env::{DEFAULT_TRUSTED_PROXIES, ENV_TRUSTED_PROXIES};
|
||||
use rustfs_trusted_proxies::config::{ConfigLoader, TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
|
||||
#[test]
|
||||
fn test_config_loader_default() {
|
||||
std::env::remove_var(ENV_TRUSTED_PROXIES);
|
||||
let config = ConfigLoader::from_env_or_default();
|
||||
assert_eq!(config.server_addr.port(), 3000);
|
||||
assert!(!config.proxy.proxies.is_empty());
|
||||
assert_eq!(config.proxy.validation_mode, ValidationMode::HopByHop);
|
||||
assert!(config.proxy.enable_rfc7239);
|
||||
assert_eq!(config.proxy.max_hops, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_loader_env_vars() {
|
||||
std::env::set_var(ENV_TRUSTED_PROXIES, "192.168.1.0/24,10.0.0.0/8");
|
||||
std::env::set_var("TRUSTED_PROXY_VALIDATION_MODE", "strict");
|
||||
std::env::set_var("TRUSTED_PROXY_MAX_HOPS", "5");
|
||||
std::env::set_var("SERVER_PORT", "8080");
|
||||
|
||||
let config = ConfigLoader::from_env();
|
||||
|
||||
if let Ok(config) = config {
|
||||
assert_eq!(config.server_addr.port(), 8080);
|
||||
assert_eq!(config.proxy.validation_mode, ValidationMode::Strict);
|
||||
assert_eq!(config.proxy.max_hops, 5);
|
||||
|
||||
std::env::remove_var(ENV_TRUSTED_PROXIES);
|
||||
std::env::remove_var("TRUSTED_PROXY_VALIDATION_MODE");
|
||||
std::env::remove_var("TRUSTED_PROXY_MAX_HOPS");
|
||||
std::env::remove_var("SERVER_PORT");
|
||||
} else {
|
||||
panic!("Failed to load configuration from environment variables");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trusted_proxy_config() {
|
||||
let proxies = vec![
|
||||
TrustedProxy::Single("192.168.1.1".parse().unwrap()),
|
||||
TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()),
|
||||
];
|
||||
|
||||
let config = TrustedProxyConfig::new(proxies.clone(), ValidationMode::Strict, true, 10, true, vec![]);
|
||||
|
||||
assert_eq!(config.proxies.len(), 2);
|
||||
assert_eq!(config.validation_mode, ValidationMode::Strict);
|
||||
assert!(config.enable_rfc7239);
|
||||
assert_eq!(config.max_hops, 10);
|
||||
assert!(config.enable_chain_continuity_check);
|
||||
|
||||
let test_ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
let test_socket_addr = std::net::SocketAddr::new(test_ip, 8080);
|
||||
assert!(config.is_trusted(&test_socket_addr));
|
||||
|
||||
let test_ip2: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
let test_socket_addr2 = std::net::SocketAddr::new(test_ip2, 8080);
|
||||
assert!(config.is_trusted(&test_socket_addr2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trusted_proxy_contains() {
|
||||
let single_proxy = TrustedProxy::Single("192.168.1.1".parse().unwrap());
|
||||
let test_ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
let test_ip2: IpAddr = "192.168.1.2".parse().unwrap();
|
||||
|
||||
assert!(single_proxy.contains(&test_ip));
|
||||
assert!(!single_proxy.contains(&test_ip2));
|
||||
|
||||
let cidr_proxy = TrustedProxy::Cidr("192.168.1.0/24".parse().unwrap());
|
||||
assert!(cidr_proxy.contains(&test_ip));
|
||||
assert!(cidr_proxy.contains(&test_ip2));
|
||||
|
||||
let test_ip3: IpAddr = "192.168.2.1".parse().unwrap();
|
||||
assert!(!cidr_proxy.contains(&test_ip3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_private_network_check() {
|
||||
let config = TrustedProxyConfig::new(
|
||||
Vec::new(),
|
||||
ValidationMode::Lenient,
|
||||
true,
|
||||
10,
|
||||
true,
|
||||
vec!["10.0.0.0/8".parse().unwrap(), "192.168.0.0/16".parse().unwrap()],
|
||||
);
|
||||
|
||||
let private_ip: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
let private_ip2: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
let public_ip: IpAddr = "8.8.8.8".parse().unwrap();
|
||||
|
||||
assert!(config.is_private_network(&private_ip));
|
||||
assert!(config.is_private_network(&private_ip2));
|
||||
assert!(!config.is_private_network(&public_ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_values() {
|
||||
assert_eq!(DEFAULT_TRUSTED_PROXIES, "127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,fd00::/8");
|
||||
}
|
||||
194
crates/trusted-proxies/tests/unit/ip_tests.rs
Normal file
194
crates/trusted-proxies/tests/unit/ip_tests.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use rustfs_trusted_proxies::utils::IpUtils;
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_ip_address() {
|
||||
let valid_ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(IpUtils::is_valid_ip_address(&valid_ip));
|
||||
|
||||
let unspecified_ip: IpAddr = "0.0.0.0".parse().unwrap();
|
||||
assert!(!IpUtils::is_valid_ip_address(&unspecified_ip));
|
||||
|
||||
let multicast_ip: IpAddr = "224.0.0.1".parse().unwrap();
|
||||
assert!(!IpUtils::is_valid_ip_address(&multicast_ip));
|
||||
|
||||
let valid_ipv6: IpAddr = "2001:db8::1".parse().unwrap();
|
||||
assert!(IpUtils::is_valid_ip_address(&valid_ipv6));
|
||||
|
||||
let unspecified_ipv6: IpAddr = "::".parse().unwrap();
|
||||
assert!(!IpUtils::is_valid_ip_address(&unspecified_ipv6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_reserved_ip() {
|
||||
let private_ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&private_ip));
|
||||
|
||||
let loopback_ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&loopback_ip));
|
||||
|
||||
let link_local_ip: IpAddr = "169.254.0.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&link_local_ip));
|
||||
|
||||
let documentation_ip: IpAddr = "192.0.2.1".parse().unwrap();
|
||||
assert!(IpUtils::is_reserved_ip(&documentation_ip));
|
||||
|
||||
let public_ip: IpAddr = "8.8.8.8".parse().unwrap();
|
||||
assert!(!IpUtils::is_reserved_ip(&public_ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_private_ip() {
|
||||
assert!(IpUtils::is_private_ip(&"10.0.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_private_ip(&"10.255.255.254".parse().unwrap()));
|
||||
|
||||
assert!(IpUtils::is_private_ip(&"172.16.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_private_ip(&"172.31.255.254".parse().unwrap()));
|
||||
assert!(!IpUtils::is_private_ip(&"172.15.0.1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_private_ip(&"172.32.0.1".parse().unwrap()));
|
||||
|
||||
assert!(IpUtils::is_private_ip(&"192.168.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_private_ip(&"192.168.255.254".parse().unwrap()));
|
||||
|
||||
assert!(!IpUtils::is_private_ip(&"8.8.8.8".parse().unwrap()));
|
||||
assert!(!IpUtils::is_private_ip(&"203.0.113.1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_loopback_ip() {
|
||||
assert!(IpUtils::is_loopback_ip(&"127.0.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_loopback_ip(&"127.255.255.254".parse().unwrap()));
|
||||
assert!(IpUtils::is_loopback_ip(&"::1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_loopback_ip(&"192.168.1.1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_loopback_ip(&"2001:db8::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_link_local_ip() {
|
||||
assert!(IpUtils::is_link_local_ip(&"169.254.0.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_link_local_ip(&"169.254.255.254".parse().unwrap()));
|
||||
assert!(IpUtils::is_link_local_ip(&"fe80::1".parse().unwrap()));
|
||||
assert!(IpUtils::is_link_local_ip(&"fe80::abcd:1234:5678:9abc".parse().unwrap()));
|
||||
assert!(!IpUtils::is_link_local_ip(&"192.168.1.1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_link_local_ip(&"2001:db8::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_documentation_ip() {
|
||||
assert!(IpUtils::is_documentation_ip(&"192.0.2.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_documentation_ip(&"198.51.100.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_documentation_ip(&"203.0.113.1".parse().unwrap()));
|
||||
assert!(IpUtils::is_documentation_ip(&"2001:db8::1".parse().unwrap()));
|
||||
assert!(!IpUtils::is_documentation_ip(&"8.8.8.8".parse().unwrap()));
|
||||
assert!(!IpUtils::is_documentation_ip(&"2001:4860::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_ip_or_cidr() {
|
||||
let result = IpUtils::parse_ip_or_cidr("192.168.1.1");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = IpUtils::parse_ip_or_cidr("192.168.1.0/24");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = IpUtils::parse_ip_or_cidr("2001:db8::1");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = IpUtils::parse_ip_or_cidr("2001:db8::/32");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = IpUtils::parse_ip_or_cidr("invalid");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_ip_list() {
|
||||
let result = IpUtils::parse_ip_list("192.168.1.1, 10.0.0.1, 8.8.8.8");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let ips = result.unwrap();
|
||||
assert_eq!(ips.len(), 3);
|
||||
assert_eq!(ips[0], IpAddr::from_str("192.168.1.1").unwrap());
|
||||
assert_eq!(ips[1], IpAddr::from_str("10.0.0.1").unwrap());
|
||||
assert_eq!(ips[2], IpAddr::from_str("8.8.8.8").unwrap());
|
||||
|
||||
let result = IpUtils::parse_ip_list("192.168.1.1,10.0.0.1");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 2);
|
||||
|
||||
let result = IpUtils::parse_ip_list("");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_empty());
|
||||
|
||||
let result = IpUtils::parse_ip_list("192.168.1.1, invalid");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_network_list() {
|
||||
let result = IpUtils::parse_network_list("192.168.1.0/24, 10.0.0.0/8");
|
||||
assert!(result.is_ok());
|
||||
|
||||
let networks = result.unwrap();
|
||||
assert_eq!(networks.len(), 2);
|
||||
|
||||
let result = IpUtils::parse_network_list("192.168.1.1");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 1);
|
||||
|
||||
let result = IpUtils::parse_network_list("192.168.1.0/24, invalid");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_in_networks() {
|
||||
let networks = vec!["10.0.0.0/8".parse().unwrap(), "192.168.1.0/24".parse().unwrap()];
|
||||
|
||||
let ip_in_network: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
let ip_in_network2: IpAddr = "192.168.1.100".parse().unwrap();
|
||||
let ip_not_in_network: IpAddr = "8.8.8.8".parse().unwrap();
|
||||
|
||||
assert!(IpUtils::ip_in_networks(&ip_in_network, &networks));
|
||||
assert!(IpUtils::ip_in_networks(&ip_in_network2, &networks));
|
||||
assert!(!IpUtils::ip_in_networks(&ip_not_in_network, &networks));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_ip_type() {
|
||||
assert_eq!(IpUtils::get_ip_type(&"10.0.0.1".parse().unwrap()), "private");
|
||||
assert_eq!(IpUtils::get_ip_type(&"127.0.0.1".parse().unwrap()), "loopback");
|
||||
assert_eq!(IpUtils::get_ip_type(&"169.254.0.1".parse().unwrap()), "link_local");
|
||||
assert_eq!(IpUtils::get_ip_type(&"192.0.2.1".parse().unwrap()), "documentation");
|
||||
assert_eq!(IpUtils::get_ip_type(&"224.0.0.1".parse().unwrap()), "reserved");
|
||||
assert_eq!(IpUtils::get_ip_type(&"8.8.8.8".parse().unwrap()), "public");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_canonical_ip() {
|
||||
let ipv4: IpAddr = "192.168.001.001".parse().unwrap();
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv4), "192.168.1.1");
|
||||
|
||||
let ipv6_full: IpAddr = "2001:0db8:0000:0000:0000:0000:0000:0001".parse().unwrap();
|
||||
let ipv6_compressed: IpAddr = "2001:db8::1".parse().unwrap();
|
||||
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv6_full), "2001:db8::1");
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv6_compressed), "2001:db8::1");
|
||||
|
||||
let ipv6_multi_zero: IpAddr = "2001:0db8:0000:0000:abcd:0000:0000:1234".parse().unwrap();
|
||||
assert_eq!(IpUtils::canonical_ip(&ipv6_multi_zero), "2001:db8::abcd:0:0:1234");
|
||||
}
|
||||
@@ -12,15 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Unit tests for the trusted proxy system
|
||||
//! Unit tests for the trusted proxy system components.
|
||||
|
||||
#[cfg(test)]
|
||||
mod config_tests;
|
||||
#[cfg(test)]
|
||||
mod ip_tests;
|
||||
#[cfg(test)]
|
||||
mod validation_tests;
|
||||
#[cfg(test)]
|
||||
mod validator_tests;
|
||||
|
||||
// 重新导出测试模块
|
||||
pub use config_tests::*;
|
||||
pub use ip_tests::*;
|
||||
pub use validation_tests::*;
|
||||
pub use validator_tests::*;
|
||||
69
crates/trusted-proxies/tests/unit/validation_tests.rs
Normal file
69
crates/trusted-proxies/tests/unit/validation_tests.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use http::HeaderMap;
|
||||
use std::net::IpAddr;
|
||||
use rustfs_trusted_proxies::utils::{IpUtils, ValidationUtils};
|
||||
|
||||
#[test]
|
||||
fn test_email_validation() {
|
||||
assert!(ValidationUtils::is_valid_email("user@example.com"));
|
||||
assert!(!ValidationUtils::is_valid_email("invalid-email"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_validation() {
|
||||
assert!(ValidationUtils::is_valid_url("https://example.com"));
|
||||
assert!(!ValidationUtils::is_valid_url("invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_x_forwarded_for_validation() {
|
||||
assert!(ValidationUtils::validate_x_forwarded_for("203.0.113.195"));
|
||||
assert!(!ValidationUtils::validate_x_forwarded_for("invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forwarded_header_validation() {
|
||||
assert!(ValidationUtils::validate_forwarded_header("for=192.0.2.60"));
|
||||
assert!(!ValidationUtils::validate_forwarded_header("invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_in_range_validation() {
|
||||
let cidr_ranges = vec![
|
||||
"10.0.0.0/8".to_string(),
|
||||
"192.168.0.0/16".to_string(),
|
||||
];
|
||||
let ip: IpAddr = "10.0.1.1".parse().unwrap();
|
||||
assert!(ValidationUtils::validate_ip_in_range(&ip, &cidr_ranges));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_header_value_validation() {
|
||||
assert!(ValidationUtils::validate_header_value("text/plain"));
|
||||
assert!(!ValidationUtils::validate_header_value(&"a".repeat(8193)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_port_validation() {
|
||||
assert!(ValidationUtils::validate_port(80));
|
||||
assert!(!ValidationUtils::validate_port(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cidr_validation() {
|
||||
assert!(ValidationUtils::validate_cidr("192.168.1.0/24"));
|
||||
assert!(!ValidationUtils::validate_cidr("invalid"));
|
||||
}
|
||||
56
crates/trusted-proxies/tests/unit/validator_tests.rs
Normal file
56
crates/trusted-proxies/tests/unit/validator_tests.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use axum::http::HeaderMap;
|
||||
use rustfs_trusted_proxies::config::{TrustedProxy, TrustedProxyConfig, ValidationMode};
|
||||
use rustfs_trusted_proxies::proxy::chain::ProxyChainAnalyzer;
|
||||
use rustfs_trusted_proxies::proxy::validator::{ClientInfo, ProxyValidator};
|
||||
|
||||
fn create_test_config() -> TrustedProxyConfig {
|
||||
let proxies = vec![
|
||||
TrustedProxy::Single("192.168.1.100".parse().unwrap()),
|
||||
TrustedProxy::Cidr("10.0.0.0/8".parse().unwrap()),
|
||||
];
|
||||
TrustedProxyConfig::new(proxies, ValidationMode::HopByHop, true, 5, true, vec![])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_info_direct() {
|
||||
let addr = SocketAddr::new(IpAddr::from([192, 168, 1, 1]), 8080);
|
||||
let client_info = ClientInfo::direct(addr);
|
||||
assert_eq!(client_info.real_ip, IpAddr::from([192, 168, 1, 1]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_x_forwarded_for() {
|
||||
let header_value = "203.0.113.195, 198.51.100.1";
|
||||
let result = ProxyValidator::parse_x_forwarded_for(header_value);
|
||||
assert_eq!(result.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_chain_analyzer_hop_by_hop() {
|
||||
let config = create_test_config();
|
||||
let analyzer = ProxyChainAnalyzer::new(config);
|
||||
let chain = vec![
|
||||
IpAddr::from_str("203.0.113.195").unwrap(),
|
||||
IpAddr::from_str("10.0.1.100").unwrap(),
|
||||
];
|
||||
let current_proxy = IpAddr::from_str("192.168.1.100").unwrap();
|
||||
let headers = HeaderMap::new();
|
||||
let result = analyzer.analyze_chain(&chain, current_proxy, &headers);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# 对象路径中的特殊字符 - 完整文档
|
||||
|
||||
本目录包含关于在 RustFS 中处理 S3 对象路径中特殊字符(空格、加号、百分号等)的完整文档。
|
||||
本目录包含关于在 RustFS 中处理 S3 对象路径中特殊字符 (空格、加号、百分号等) 的完整文档。
|
||||
|
||||
## 快速链接
|
||||
|
||||
@@ -12,18 +12,18 @@
|
||||
|
||||
### 问题现象
|
||||
|
||||
用户报告了两个问题:
|
||||
1. **问题 A**: UI 可以导航到包含特殊字符的文件夹,但无法列出其中的内容
|
||||
用户报告了两个问题:
|
||||
1. **问题 A**: UI 可以导航到包含特殊字符的文件夹,但无法列出其中的内容
|
||||
2. **问题 B**: 上传路径中包含 `+` 号的文件时出现 400 错误
|
||||
|
||||
### 根本原因
|
||||
|
||||
经过深入调查,包括检查 s3s 库的源代码,我们发现:
|
||||
经过深入调查,包括检查 s3s 库的源代码,我们发现:
|
||||
|
||||
**后端 (RustFS) 工作正常** ✅
|
||||
- s3s 库正确地对 HTTP 请求中的对象键进行 URL 解码
|
||||
- RustFS 正确存储和检索包含特殊字符的对象
|
||||
- 命令行工具(mc, aws-cli)完美工作 → 证明后端正确处理特殊字符
|
||||
- 命令行工具 (mc, aws-cli) 完美工作 → 证明后端正确处理特殊字符
|
||||
|
||||
**问题出在 UI/客户端层** ❌
|
||||
- 某些客户端未正确进行 URL 编码
|
||||
@@ -33,8 +33,8 @@
|
||||
### 解决方案
|
||||
|
||||
1. **用户**: 使用正规的 S3 SDK/客户端(它们会自动处理编码)
|
||||
2. **开发者**: 后端无需修复,但添加了防御性验证和日志
|
||||
3. **UI**: UI 需要正确对所有请求进行 URL 编码(如适用)
|
||||
2. **开发者**: 后端无需修复,但添加了防御性验证和日志
|
||||
3. **UI**: UI 需要正确对所有请求进行 URL 编码 (如适用)
|
||||
|
||||
## URL 编码快速参考
|
||||
|
||||
@@ -44,11 +44,11 @@
|
||||
| 加号 | `+` | `%2B` | `%2B` |
|
||||
| 百分号 | `%` | `%25` | `%25` |
|
||||
|
||||
**重要**: 在 URL **路径**中,`+` = 字面加号(不是空格)。只有 `%20` = 空格!
|
||||
**重要**: 在 URL **路径**中,`+` = 字面加号 (不是空格)。只有 `%20` = 空格!
|
||||
|
||||
## 快速示例
|
||||
|
||||
### ✅ 正确使用(使用 mc)
|
||||
### ✅ 正确使用 (使用 mc)
|
||||
|
||||
```bash
|
||||
# 上传
|
||||
@@ -60,7 +60,7 @@ mc ls "myrustfs/bucket/路径 包含 空格/"
|
||||
# 结果: ✅ 成功 - mc 正确编码了请求
|
||||
```
|
||||
|
||||
### ❌ 可能失败(原始 HTTP 未编码)
|
||||
### ❌ 可能失败 (原始 HTTP 未编码)
|
||||
|
||||
```bash
|
||||
# 错误: 未编码
|
||||
@@ -82,7 +82,7 @@ curl "http://localhost:9000/bucket/%E8%B7%AF%E5%BE%84%20%E5%8C%85%E5%90%AB%20%E7
|
||||
|
||||
### ✅ 已完成
|
||||
|
||||
1. **后端验证**: 添加了控制字符验证(拒绝空字节、换行符)
|
||||
1. **后端验证**: 添加了控制字符验证 (拒绝空字节、换行符)
|
||||
2. **调试日志**: 为包含特殊字符的键添加了日志记录
|
||||
3. **测试**: 创建了综合 e2e 测试套件
|
||||
4. **文档**:
|
||||
@@ -103,7 +103,7 @@ curl "http://localhost:9000/bucket/%E8%B7%AF%E5%BE%84%20%E5%8C%85%E5%90%AB%20%E7
|
||||
3. **用户沟通**:
|
||||
- 更新用户文档
|
||||
- 在 FAQ 中添加故障排除
|
||||
- 传达已知的 UI 限制(如有)
|
||||
- 传达已知的 UI 限制 (如有)
|
||||
|
||||
## 测试
|
||||
|
||||
@@ -134,15 +134,15 @@ aws --endpoint-url=http://localhost:9000 s3 ls "s3://bucket/测试 包含 空格
|
||||
|
||||
## 支持
|
||||
|
||||
如果遇到特殊字符问题:
|
||||
如果遇到特殊字符问题:
|
||||
|
||||
1. **首先**: 查看[客户端指南](./client-special-characters-guide.md)
|
||||
2. **尝试**: 使用 mc 或 AWS CLI 隔离问题
|
||||
3. **启用**: 调试日志: `RUST_LOG=rustfs=debug`
|
||||
4. **报告**: 创建问题,包含:
|
||||
3. **启用**: 调试日志:`RUST_LOG=rustfs=debug`
|
||||
4. **报告**: 创建问题,包含:
|
||||
- 使用的客户端/SDK
|
||||
- 导致问题的确切对象名称
|
||||
- mc 是否工作(以隔离后端与客户端)
|
||||
- mc 是否工作 (以隔离后端与客户端)
|
||||
- 调试日志
|
||||
|
||||
## 相关文档
|
||||
@@ -154,26 +154,26 @@ aws --endpoint-url=http://localhost:9000 s3 ls "s3://bucket/测试 包含 空格
|
||||
|
||||
## 常见问题
|
||||
|
||||
**问: 可以在对象名称中使用空格吗?**
|
||||
答: 可以,但请使用能自动处理编码的 S3 SDK。
|
||||
**问:可以在对象名称中使用空格吗?**
|
||||
答:可以,但请使用能自动处理编码的 S3 SDK。
|
||||
|
||||
**问: 为什么 `+` 不能用作空格?**
|
||||
答: 在 URL 路径中,`+` 表示字面加号。只有在查询参数中 `+` 才表示空格。在路径中使用 `%20` 表示空格。
|
||||
**问:为什么 `+` 不能用作空格?**
|
||||
答:在 URL 路径中,`+` 表示字面加号。只有在查询参数中 `+` 才表示空格。在路径中使用 `%20` 表示空格。
|
||||
|
||||
**问: RustFS 支持对象名称中的 Unicode 吗?**
|
||||
答: 支持,对象名称是 UTF-8 字符串。它们支持任何有效的 UTF-8 字符。
|
||||
**问:RustFS 支持对象名称中的 Unicode 吗?**
|
||||
答:支持,对象名称是 UTF-8 字符串。它们支持任何有效的 UTF-8 字符。
|
||||
|
||||
**问: 哪些字符是禁止的?**
|
||||
答: 控制字符(空字节、换行符、回车符)被拒绝。所有可打印字符都是允许的。
|
||||
**问:哪些字符是禁止的?**
|
||||
答:控制字符 (空字节、换行符、回车符) 被拒绝。所有可打印字符都是允许的。
|
||||
|
||||
**问: 如何修复"UI 无法列出文件夹"的问题?**
|
||||
答: 使用 CLI(mc 或 aws-cli)代替。这是 UI 错误,不是后端问题。
|
||||
**问:如何修复"UI 无法列出文件夹"的问题?**
|
||||
答:使用 CLI(mc 或 aws-cli) 代替。这是 UI 错误,不是后端问题。
|
||||
|
||||
## 版本历史
|
||||
|
||||
- **v1.0** (2025-12-09): 初始文档
|
||||
- 完成综合分析
|
||||
- 确定根本原因(UI/客户端问题)
|
||||
- 确定根本原因 (UI/客户端问题)
|
||||
- 添加后端验证和日志
|
||||
- 创建客户端指南
|
||||
- 添加 E2E 测试
|
||||
|
||||
Reference in New Issue
Block a user