From 8f310cd4a8ed32cce7c9630929bc70e503448b7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=89=E6=AD=A3=E8=B6=85?= Date: Wed, 15 Oct 2025 21:24:03 +0800 Subject: [PATCH] test: allow mocking dns resolver (#656) --- crates/utils/src/net.rs | 112 ++++++++++++++++++++++++++++------------ 1 file changed, 80 insertions(+), 32 deletions(-) diff --git a/crates/utils/src/net.rs b/crates/utils/src/net.rs index 9c0f3bc6..51f1220d 100644 --- a/crates/utils/src/net.rs +++ b/crates/utils/src/net.rs @@ -16,14 +16,16 @@ use bytes::Bytes; use futures::pin_mut; use futures::{Stream, StreamExt}; use std::io::Error; -use std::net::Ipv6Addr; -use std::sync::{LazyLock, Mutex}; use std::{ collections::{HashMap, HashSet}, fmt::Display, net::{IpAddr, SocketAddr, TcpListener, ToSocketAddrs}, time::{Duration, Instant}, }; +use std::{ + net::Ipv6Addr, + sync::{Arc, LazyLock, Mutex, RwLock}, +}; use tracing::{error, info}; use transform_stream::AsyncTryStream; use url::{Host, Url}; @@ -51,6 +53,41 @@ impl DnsCacheEntry { static DNS_CACHE: LazyLock>> = LazyLock::new(|| Mutex::new(HashMap::new())); const DNS_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes +type DynDnsResolver = dyn Fn(&str) -> std::io::Result> + Send + Sync + 'static; +static CUSTOM_DNS_RESOLVER: LazyLock>>> = LazyLock::new(|| RwLock::new(None)); + +fn resolve_domain(domain: &str) -> std::io::Result> { + if let Some(resolver) = CUSTOM_DNS_RESOLVER.read().unwrap().clone() { + return resolver(domain); + } + + (domain, 0) + .to_socket_addrs() + .map(|v| v.map(|v| v.ip()).collect::>()) + .map_err(Error::other) +} + +#[cfg(test)] +fn clear_dns_cache() { + if let Ok(mut cache) = DNS_CACHE.lock() { + cache.clear(); + } +} + +#[cfg(test)] +pub fn set_mock_dns_resolver(resolver: F) +where + F: Fn(&str) -> std::io::Result> + Send + Sync + 'static, +{ + *CUSTOM_DNS_RESOLVER.write().unwrap() = Some(Arc::new(resolver)); + clear_dns_cache(); +} + +#[cfg(test)] +pub fn reset_dns_resolver() { + *CUSTOM_DNS_RESOLVER.write().unwrap() = None; + clear_dns_cache(); +} /// helper for validating if the provided arg is an ip address. pub fn is_socket_addr(addr: &str) -> bool { @@ -93,10 +130,7 @@ pub fn is_local_host(host: Host<&str>, port: u16, local_port: u16) -> std::io::R let local_set: HashSet = LOCAL_IPS.iter().copied().collect(); let is_local_host = match host { Host::Domain(domain) => { - let ips = match (domain, 0).to_socket_addrs().map(|v| v.map(|v| v.ip()).collect::>()) { - Ok(ips) => ips, - Err(err) => return Err(Error::other(err)), - }; + let ips = resolve_domain(domain)?.into_iter().collect::>(); ips.iter().any(|ip| local_set.contains(ip)) } @@ -130,30 +164,31 @@ pub async fn get_host_ip(host: Host<&str>) -> std::io::Result> { // } // } // Check cache first - if let Ok(mut cache) = DNS_CACHE.lock() { - if let Some(entry) = cache.get(domain) { - if !entry.is_expired(DNS_CACHE_TTL) { - return Ok(entry.ips.clone()); + if CUSTOM_DNS_RESOLVER.read().unwrap().is_none() { + if let Ok(mut cache) = DNS_CACHE.lock() { + if let Some(entry) = cache.get(domain) { + if !entry.is_expired(DNS_CACHE_TTL) { + return Ok(entry.ips.clone()); + } + // Remove expired entry + cache.remove(domain); } - // Remove expired entry - cache.remove(domain); } } info!("Cache miss for domain {domain}, querying system resolver."); // Fallback to standard resolution when DNS resolver is not available - match (domain, 0) - .to_socket_addrs() - .map(|v| v.map(|v| v.ip()).collect::>()) - { + match resolve_domain(domain) { Ok(ips) => { - // Cache the result - if let Ok(mut cache) = DNS_CACHE.lock() { - cache.insert(domain.to_string(), DnsCacheEntry::new(ips.clone())); - // Limit cache size to prevent memory bloat - if cache.len() > 1000 { - cache.retain(|_, v| !v.is_expired(DNS_CACHE_TTL)); + if CUSTOM_DNS_RESOLVER.read().unwrap().is_none() { + // Cache the result + if let Ok(mut cache) = DNS_CACHE.lock() { + cache.insert(domain.to_string(), DnsCacheEntry::new(ips.clone())); + // Limit cache size to prevent memory bloat + if cache.len() > 1000 { + cache.retain(|_, v| !v.is_expired(DNS_CACHE_TTL)); + } } } info!("System query for domain {domain}: {:?}", ips); @@ -292,6 +327,21 @@ mod test { use super::*; use crate::init_global_dns_resolver; use std::net::{Ipv4Addr, Ipv6Addr}; + use std::{collections::HashSet, io::Error as IoError}; + + fn mock_resolver(domain: &str) -> std::io::Result> { + match domain { + "localhost" => Ok([ + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + ] + .into_iter() + .collect()), + "example.org" => Ok([IpAddr::V4(Ipv4Addr::new(192, 0, 2, 10))].into_iter().collect()), + "invalid.nonexistent.domain.example" => Err(IoError::other("mock DNS failure")), + _ => Ok(HashSet::new()), + } + } #[test] fn test_is_socket_addr() { @@ -349,7 +399,7 @@ mod test { let invalid_cases = [ ("localhost", "invalid socket address"), ("", "invalid socket address"), - ("example.org:54321", "host in server address should be this server"), + ("203.0.113.1:54321", "host in server address should be this server"), ("8.8.8.8:53", "host in server address should be this server"), (":-10", "invalid port value"), ("invalid:port", "invalid port value"), @@ -369,6 +419,8 @@ mod test { #[test] fn test_is_local_host() { + set_mock_dns_resolver(mock_resolver); + // Test localhost domain let localhost_host = Host::Domain("localhost"); assert!(is_local_host(localhost_host, 0, 0).unwrap()); @@ -393,10 +445,13 @@ mod test { // Test invalid domain should return error let invalid_host = Host::Domain("invalid.nonexistent.domain.example"); assert!(is_local_host(invalid_host, 0, 0).is_err()); + + reset_dns_resolver(); } #[tokio::test] async fn test_get_host_ip() { + set_mock_dns_resolver(mock_resolver); match init_global_dns_resolver().await { Ok(_) => {} Err(e) => { @@ -427,16 +482,9 @@ mod test { // Test invalid domain let invalid_host = Host::Domain("invalid.nonexistent.domain.example"); - match get_host_ip(invalid_host.clone()).await { - Ok(ips) => { - // Depending on DNS resolver behavior, it might return empty set or error - assert!(ips.is_empty(), "Expected empty IP set for invalid domain, got: {ips:?}"); - } - Err(_) => { - error!("Expected error for invalid domain"); - } // Expected error - } assert!(get_host_ip(invalid_host).await.is_err()); + + reset_dns_resolver(); } #[test]