mirror of
https://github.com/rustfs/rustfs.git
synced 2026-03-17 14:24:08 +00:00
fix(utils): harden panic-prone paths (#2113)
Signed-off-by: houseme <housemecn@gmail.com> Co-authored-by: houseme <housemecn@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: heihutu <30542132+heihutu@users.noreply.github.com>
This commit is contained in:
@@ -93,7 +93,11 @@ pub fn build_webpki_client_verifier(tls_path: &str) -> io::Result<Option<Arc<dyn
|
||||
))
|
||||
})?;
|
||||
|
||||
let der_list = load_cert_bundle_der_bytes(ca_path.to_str().unwrap_or_default())?;
|
||||
let ca_path = ca_path
|
||||
.to_str()
|
||||
.ok_or_else(|| Error::other(format!("Invalid UTF-8 in mTLS CA path: {ca_path:?}")))?;
|
||||
|
||||
let der_list = load_cert_bundle_der_bytes(ca_path)?;
|
||||
|
||||
let mut store = RootCertStore::empty();
|
||||
for der in der_list {
|
||||
@@ -219,7 +223,23 @@ pub fn load_all_certs_from_directory(
|
||||
|
||||
if cert_path.exists() && key_path.exists() {
|
||||
debug!("find the domain name certificate: {} in {:?}", domain_name, cert_path);
|
||||
match load_cert_key_pair(cert_path.to_str().unwrap(), key_path.to_str().unwrap()) {
|
||||
let cert_path = match cert_path.to_str() {
|
||||
Some(path) => path,
|
||||
None => {
|
||||
warn!("skip domain certificate load, invalid UTF-8 path: {:?}", cert_path);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let key_path = match key_path.to_str() {
|
||||
Some(path) => path,
|
||||
None => {
|
||||
warn!("skip domain key load, invalid UTF-8 path: {:?}", key_path);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match load_cert_key_pair(cert_path, key_path) {
|
||||
Ok((certs, key)) => {
|
||||
cert_key_pairs.insert(domain_name.to_string(), (certs, key));
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fmt::Display,
|
||||
io::Error,
|
||||
net::{IpAddr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
|
||||
sync::{Arc, LazyLock, Mutex, RwLock},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
@@ -26,7 +26,7 @@ use tracing::{error, info};
|
||||
use transform_stream::AsyncTryStream;
|
||||
use url::{Host, Url};
|
||||
|
||||
static LOCAL_IPS: LazyLock<Vec<IpAddr>> = LazyLock::new(|| must_get_local_ips().unwrap());
|
||||
static LOCAL_IPS: LazyLock<Vec<IpAddr>> = LazyLock::new(get_local_ips_with_fallback);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DnsCacheEntry {
|
||||
@@ -53,7 +53,7 @@ type DynDnsResolver = dyn Fn(&str) -> std::io::Result<HashSet<IpAddr>> + Send +
|
||||
static CUSTOM_DNS_RESOLVER: LazyLock<RwLock<Option<Arc<DynDnsResolver>>>> = LazyLock::new(|| RwLock::new(None));
|
||||
|
||||
fn resolve_domain(domain: &str) -> std::io::Result<HashSet<IpAddr>> {
|
||||
if let Some(resolver) = CUSTOM_DNS_RESOLVER.read().unwrap().clone() {
|
||||
if let Some(resolver) = get_custom_dns_resolver() {
|
||||
return resolver(domain);
|
||||
}
|
||||
|
||||
@@ -164,6 +164,28 @@ pub fn is_local_host(host: Host<&str>, port: u16, local_port: u16) -> std::io::R
|
||||
Ok(is_local_host)
|
||||
}
|
||||
|
||||
fn get_custom_dns_resolver() -> Option<Arc<DynDnsResolver>> {
|
||||
match CUSTOM_DNS_RESOLVER.read() {
|
||||
Ok(guard) => guard.clone(),
|
||||
Err(poisoned) => {
|
||||
error!("CUSTOM_DNS_RESOLVER RwLock is poisoned; using resolver value despite poisoning");
|
||||
let guard = poisoned.into_inner();
|
||||
guard.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn has_custom_dns_resolver() -> bool {
|
||||
get_custom_dns_resolver().is_some()
|
||||
}
|
||||
|
||||
fn get_local_ips_with_fallback() -> Vec<IpAddr> {
|
||||
match must_get_local_ips() {
|
||||
Ok(ips) if !ips.is_empty() => ips,
|
||||
_ => vec![IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), IpAddr::V6(Ipv6Addr::LOCALHOST)],
|
||||
}
|
||||
}
|
||||
|
||||
/// returns IP address of given host using layered DNS resolution.
|
||||
///
|
||||
/// This is the async version of `get_host_ip()` that provides enhanced DNS resolution
|
||||
@@ -172,7 +194,7 @@ pub async fn get_host_ip(host: Host<&str>) -> std::io::Result<HashSet<IpAddr>> {
|
||||
match host {
|
||||
Host::Domain(domain) => {
|
||||
// Check cache first
|
||||
if CUSTOM_DNS_RESOLVER.read().unwrap().is_none()
|
||||
if !has_custom_dns_resolver()
|
||||
&& let Ok(mut cache) = DNS_CACHE.lock()
|
||||
&& let Some(entry) = cache.get(domain)
|
||||
{
|
||||
@@ -188,7 +210,7 @@ pub async fn get_host_ip(host: Host<&str>) -> std::io::Result<HashSet<IpAddr>> {
|
||||
// Fallback to standard resolution when DNS resolver is not available
|
||||
match resolve_domain(domain) {
|
||||
Ok(ips) => {
|
||||
if CUSTOM_DNS_RESOLVER.read().unwrap().is_none() {
|
||||
if !has_custom_dns_resolver() {
|
||||
// Cache the result
|
||||
if let Ok(mut cache) = DNS_CACHE.lock() {
|
||||
cache.insert(domain.to_string(), DnsCacheEntry::new(ips.clone()));
|
||||
@@ -213,7 +235,16 @@ pub async fn get_host_ip(host: Host<&str>) -> std::io::Result<HashSet<IpAddr>> {
|
||||
}
|
||||
|
||||
pub fn get_available_port() -> u16 {
|
||||
TcpListener::bind("0.0.0.0:0").unwrap().local_addr().unwrap().port()
|
||||
try_get_available_port().unwrap_or_default()
|
||||
}
|
||||
|
||||
fn try_get_available_port() -> std::io::Result<u16> {
|
||||
let listener =
|
||||
TcpListener::bind("0.0.0.0:0").map_err(|err| Error::other(format!("Failed to bind for ephemeral port: {err}")))?;
|
||||
listener
|
||||
.local_addr()
|
||||
.map(|addr| addr.port())
|
||||
.map_err(|err| Error::other(format!("Failed to read ephemeral port: {err}")))
|
||||
}
|
||||
|
||||
/// returns IPs of local interface
|
||||
@@ -297,7 +328,7 @@ pub fn parse_and_resolve_address(addr_str: &str) -> std::io::Result<SocketAddr>
|
||||
.parse()
|
||||
.map_err(|e| Error::other(format!("Invalid port format: {addr_str}, err:{e:?}")))?;
|
||||
let final_port = if port == 0 {
|
||||
get_available_port() // assume get_available_port is available here
|
||||
try_get_available_port()? // assume get_available_port is available here
|
||||
} else {
|
||||
port
|
||||
};
|
||||
@@ -305,7 +336,7 @@ pub fn parse_and_resolve_address(addr_str: &str) -> std::io::Result<SocketAddr>
|
||||
} else {
|
||||
let mut addr = check_local_server_addr(addr_str)?; // assume check_local_server_addr is available here
|
||||
if addr.port() == 0 {
|
||||
addr.set_port(get_available_port());
|
||||
addr.set_port(try_get_available_port()?);
|
||||
}
|
||||
addr
|
||||
};
|
||||
|
||||
@@ -200,8 +200,8 @@ impl std::fmt::Display for ParsedURL {
|
||||
&& let Some(port) = url.port()
|
||||
&& ((url.scheme() == "http" && port == 80) || (url.scheme() == "https" && port == 443))
|
||||
{
|
||||
url.set_host(Some(&host)).unwrap();
|
||||
url.set_port(None).unwrap();
|
||||
let _ = url.set_host(Some(&host));
|
||||
let _ = url.set_port(None);
|
||||
}
|
||||
let mut s = url.to_string();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user