fix(utils): harden panic-prone paths (#2113)

Signed-off-by: houseme <housemecn@gmail.com>
Co-authored-by: houseme <housemecn@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: heihutu <30542132+heihutu@users.noreply.github.com>
This commit is contained in:
安正超
2026-03-11 15:16:03 +08:00
committed by GitHub
parent 9908a44c38
commit 7d7e0b2654
3 changed files with 63 additions and 12 deletions

View File

@@ -93,7 +93,11 @@ pub fn build_webpki_client_verifier(tls_path: &str) -> io::Result<Option<Arc<dyn
))
})?;
let der_list = load_cert_bundle_der_bytes(ca_path.to_str().unwrap_or_default())?;
let ca_path = ca_path
.to_str()
.ok_or_else(|| Error::other(format!("Invalid UTF-8 in mTLS CA path: {ca_path:?}")))?;
let der_list = load_cert_bundle_der_bytes(ca_path)?;
let mut store = RootCertStore::empty();
for der in der_list {
@@ -219,7 +223,23 @@ pub fn load_all_certs_from_directory(
if cert_path.exists() && key_path.exists() {
debug!("find the domain name certificate: {} in {:?}", domain_name, cert_path);
match load_cert_key_pair(cert_path.to_str().unwrap(), key_path.to_str().unwrap()) {
let cert_path = match cert_path.to_str() {
Some(path) => path,
None => {
warn!("skip domain certificate load, invalid UTF-8 path: {:?}", cert_path);
continue;
}
};
let key_path = match key_path.to_str() {
Some(path) => path,
None => {
warn!("skip domain key load, invalid UTF-8 path: {:?}", key_path);
continue;
}
};
match load_cert_key_pair(cert_path, key_path) {
Ok((certs, key)) => {
cert_key_pairs.insert(domain_name.to_string(), (certs, key));
}

View File

@@ -18,7 +18,7 @@ use std::{
collections::{HashMap, HashSet},
fmt::Display,
io::Error,
net::{IpAddr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
sync::{Arc, LazyLock, Mutex, RwLock},
time::{Duration, Instant},
};
@@ -26,7 +26,7 @@ use tracing::{error, info};
use transform_stream::AsyncTryStream;
use url::{Host, Url};
static LOCAL_IPS: LazyLock<Vec<IpAddr>> = LazyLock::new(|| must_get_local_ips().unwrap());
static LOCAL_IPS: LazyLock<Vec<IpAddr>> = LazyLock::new(get_local_ips_with_fallback);
#[derive(Debug, Clone)]
struct DnsCacheEntry {
@@ -53,7 +53,7 @@ type DynDnsResolver = dyn Fn(&str) -> std::io::Result<HashSet<IpAddr>> + Send +
static CUSTOM_DNS_RESOLVER: LazyLock<RwLock<Option<Arc<DynDnsResolver>>>> = LazyLock::new(|| RwLock::new(None));
fn resolve_domain(domain: &str) -> std::io::Result<HashSet<IpAddr>> {
if let Some(resolver) = CUSTOM_DNS_RESOLVER.read().unwrap().clone() {
if let Some(resolver) = get_custom_dns_resolver() {
return resolver(domain);
}
@@ -164,6 +164,28 @@ pub fn is_local_host(host: Host<&str>, port: u16, local_port: u16) -> std::io::R
Ok(is_local_host)
}
fn get_custom_dns_resolver() -> Option<Arc<DynDnsResolver>> {
match CUSTOM_DNS_RESOLVER.read() {
Ok(guard) => guard.clone(),
Err(poisoned) => {
error!("CUSTOM_DNS_RESOLVER RwLock is poisoned; using resolver value despite poisoning");
let guard = poisoned.into_inner();
guard.clone()
}
}
}
fn has_custom_dns_resolver() -> bool {
get_custom_dns_resolver().is_some()
}
fn get_local_ips_with_fallback() -> Vec<IpAddr> {
match must_get_local_ips() {
Ok(ips) if !ips.is_empty() => ips,
_ => vec![IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), IpAddr::V6(Ipv6Addr::LOCALHOST)],
}
}
/// returns IP address of given host using layered DNS resolution.
///
/// This is the async version of `get_host_ip()` that provides enhanced DNS resolution
@@ -172,7 +194,7 @@ pub async fn get_host_ip(host: Host<&str>) -> std::io::Result<HashSet<IpAddr>> {
match host {
Host::Domain(domain) => {
// Check cache first
if CUSTOM_DNS_RESOLVER.read().unwrap().is_none()
if !has_custom_dns_resolver()
&& let Ok(mut cache) = DNS_CACHE.lock()
&& let Some(entry) = cache.get(domain)
{
@@ -188,7 +210,7 @@ pub async fn get_host_ip(host: Host<&str>) -> std::io::Result<HashSet<IpAddr>> {
// Fallback to standard resolution when DNS resolver is not available
match resolve_domain(domain) {
Ok(ips) => {
if CUSTOM_DNS_RESOLVER.read().unwrap().is_none() {
if !has_custom_dns_resolver() {
// Cache the result
if let Ok(mut cache) = DNS_CACHE.lock() {
cache.insert(domain.to_string(), DnsCacheEntry::new(ips.clone()));
@@ -213,7 +235,16 @@ pub async fn get_host_ip(host: Host<&str>) -> std::io::Result<HashSet<IpAddr>> {
}
pub fn get_available_port() -> u16 {
TcpListener::bind("0.0.0.0:0").unwrap().local_addr().unwrap().port()
try_get_available_port().unwrap_or_default()
}
fn try_get_available_port() -> std::io::Result<u16> {
let listener =
TcpListener::bind("0.0.0.0:0").map_err(|err| Error::other(format!("Failed to bind for ephemeral port: {err}")))?;
listener
.local_addr()
.map(|addr| addr.port())
.map_err(|err| Error::other(format!("Failed to read ephemeral port: {err}")))
}
/// returns IPs of local interface
@@ -297,7 +328,7 @@ pub fn parse_and_resolve_address(addr_str: &str) -> std::io::Result<SocketAddr>
.parse()
.map_err(|e| Error::other(format!("Invalid port format: {addr_str}, err:{e:?}")))?;
let final_port = if port == 0 {
get_available_port() // assume get_available_port is available here
try_get_available_port()? // assume get_available_port is available here
} else {
port
};
@@ -305,7 +336,7 @@ pub fn parse_and_resolve_address(addr_str: &str) -> std::io::Result<SocketAddr>
} else {
let mut addr = check_local_server_addr(addr_str)?; // assume check_local_server_addr is available here
if addr.port() == 0 {
addr.set_port(get_available_port());
addr.set_port(try_get_available_port()?);
}
addr
};

View File

@@ -200,8 +200,8 @@ impl std::fmt::Display for ParsedURL {
&& let Some(port) = url.port()
&& ((url.scheme() == "http" && port == 80) || (url.scheme() == "https" && port == 443))
{
url.set_host(Some(&host)).unwrap();
url.set_port(None).unwrap();
let _ = url.set_host(Some(&host));
let _ = url.set_port(None);
}
let mut s = url.to_string();