diff --git a/crates/signer/src/request_signature_v2.rs b/crates/signer/src/request_signature_v2.rs index 64c4407e..300ed9ae 100644 --- a/crates/signer/src/request_signature_v2.rs +++ b/crates/signer/src/request_signature_v2.rs @@ -54,7 +54,8 @@ pub fn pre_sign_v2( let string_to_sign = pre_string_to_sign_v2(&req, virtual_host); let signature = hex(hmac_sha1(secret_access_key, string_to_sign)); - let result = serde_urlencoded::from_str::>(req.uri().query().unwrap()); + let query_source = req.uri().query().unwrap_or(""); + let result = serde_urlencoded::from_str::>(query_source); let mut query = result.unwrap_or_default(); if get_host_addr(&req).contains(".storage.googleapis.com") { query.insert("GoogleAccessId".to_string(), access_key_id.to_string()); @@ -95,21 +96,23 @@ pub fn sign_v2( let d = OffsetDateTime::now_utc(); let d2 = d.replace_time(time::Time::from_hms(0, 0, 0).unwrap()); + { + let headers = req.headers_mut(); + let need_default_date = headers.get("Date").and_then(|v| v.to_str().ok()).is_none_or(|v| v.is_empty()); + if need_default_date { + headers.insert( + "Date", + d2.format(&format_description::well_known::Rfc2822) + .unwrap() + .to_string() + .parse() + .unwrap(), + ); + } + } let string_to_sign = string_to_sign_v2(&req, virtual_host); let headers = req.headers_mut(); - let date = headers.get("Date").unwrap(); - if date.to_str().unwrap() == "" { - headers.insert( - "Date", - d2.format(&format_description::well_known::Rfc2822) - .unwrap() - .to_string() - .parse() - .unwrap(), - ); - } - let auth_header = format!("{SIGN_V2_ALGORITHM} {access_key_id}:"); let auth_header = format!( "{}{}", @@ -133,11 +136,11 @@ fn pre_string_to_sign_v2(req: &request::Request, virtual_host: bool) -> St fn write_pre_sign_v2_headers(buf: &mut BytesMut, req: &request::Request) { let _ = buf.write_str(req.method().as_str()); let _ = buf.write_char('\n'); - let _ = buf.write_str(req.headers().get("Content-Md5").unwrap().to_str().unwrap()); + let _ = buf.write_str(req.headers().get("Content-Md5").and_then(|v| v.to_str().ok()).unwrap_or("")); let _ = buf.write_char('\n'); - let _ = buf.write_str(req.headers().get("Content-Type").unwrap().to_str().unwrap()); + let _ = buf.write_str(req.headers().get("Content-Type").and_then(|v| v.to_str().ok()).unwrap_or("")); let _ = buf.write_char('\n'); - let _ = buf.write_str(req.headers().get("Expires").unwrap().to_str().unwrap()); + let _ = buf.write_str(req.headers().get("Expires").and_then(|v| v.to_str().ok()).unwrap_or("")); let _ = buf.write_char('\n'); } @@ -153,11 +156,11 @@ fn write_sign_v2_headers(buf: &mut BytesMut, req: &request::Request) { let headers = req.headers(); let _ = buf.write_str(req.method().as_str()); let _ = buf.write_char('\n'); - let _ = buf.write_str(headers.get("Content-Md5").unwrap().to_str().unwrap()); + let _ = buf.write_str(headers.get("Content-Md5").and_then(|v| v.to_str().ok()).unwrap_or("")); let _ = buf.write_char('\n'); - let _ = buf.write_str(headers.get("Content-Type").unwrap().to_str().unwrap()); + let _ = buf.write_str(headers.get("Content-Type").and_then(|v| v.to_str().ok()).unwrap_or("")); let _ = buf.write_char('\n'); - let _ = buf.write_str(headers.get("Date").unwrap().to_str().unwrap()); + let _ = buf.write_str(headers.get("Date").and_then(|v| v.to_str().ok()).unwrap_or("")); let _ = buf.write_char('\n'); } @@ -172,7 +175,7 @@ fn write_canonicalized_headers(buf: &mut BytesMut, req: &request::Request) .headers() .get_all(k) .iter() - .map(|e| e.to_str().unwrap().to_string()) + .filter_map(|e| e.to_str().ok().map(ToString::to_string)) .collect(); vals.insert(lk, vv); } @@ -218,28 +221,107 @@ const INCLUDED_QUERY: &[&str] = &[ fn write_canonicalized_resource(buf: &mut BytesMut, req: &request::Request, virtual_host: bool) { let request_url = req.uri(); let _ = buf.write_str(&encode_url2path(req, virtual_host)); - if request_url.query().unwrap() != "" { - let mut n: i64 = 0; - let result = serde_urlencoded::from_str::>>(req.uri().query().unwrap()); - let vals = result.unwrap_or_default(); + if let Some(query_str) = request_url.query().filter(|query| !query.is_empty()) { + let mut query_vals = HashMap::new(); + for pair in query_str.split('&') { + let mut iter = pair.splitn(2, '='); + let key = match iter.next() { + Some(k) if !k.is_empty() => k, + _ => continue, + }; + let value = iter.next().unwrap_or(""); + query_vals.insert(key.to_string(), value.to_string()); + } + + let mut canonical = Vec::::new(); for resource in INCLUDED_QUERY { - let vv = &vals[*resource]; - if !vv.is_empty() { - n += 1; - match n { - 1 => { - let _ = buf.write_char('?'); - } - _ => { - let _ = buf.write_char('&'); - let _ = buf.write_str(resource); - if !vv[0].is_empty() { - let _ = buf.write_char('='); - let _ = buf.write_str(&vv[0]); - } - } + if let Some(value) = query_vals.get(*resource) { + let mut item = resource.to_string(); + if !value.is_empty() { + let _ = write!(&mut item, "={value}"); } + canonical.push(item); } } + if !canonical.is_empty() { + let _ = buf.write_char('?'); + let _ = buf.write_str(&canonical.join("&")); + } + } +} + +#[cfg(test)] +#[allow(unused_variables, unused_mut)] +mod tests { + use std::collections::HashMap; + + use super::*; + + #[test] + fn test_pre_sign_v2_without_query_should_keep_safe() { + let mut req = request::Request::builder() + .method(http::Method::GET) + .uri("http://examplebucket.s3.amazonaws.com/object") + .body(Body::empty()) + .unwrap(); + req.headers_mut() + .insert("host", "examplebucket.s3.amazonaws.com".parse().unwrap()); + + let req = pre_sign_v2(req, "AKIAEXAMPLE", "SECRET", 60, false); + let query = req.uri().query().unwrap_or_default(); + let values = serde_urlencoded::from_str::>(query).unwrap(); + + assert!(values.contains_key("AWSAccessKeyId")); + assert!(values.contains_key("Expires")); + assert!(values.contains_key("Signature")); + } + + #[test] + fn test_sign_v2_missing_optional_headers() { + let mut req = request::Request::builder() + .method(http::Method::GET) + .uri("http://examplebucket.s3.amazonaws.com/object") + .body(Body::empty()) + .unwrap(); + req.headers_mut() + .insert("host", "examplebucket.s3.amazonaws.com".parse().unwrap()); + + let req = sign_v2(req, 0, "AKIAEXAMPLE", "SECRET", false); + assert!(req.headers().get("Date").is_some()); + assert!(req.headers().get("Authorization").is_some()); + } + + #[test] + fn test_write_canonicalized_resource_with_single_query_param() { + let mut req = request::Request::builder() + .method(http::Method::GET) + .uri("http://examplebucket.s3.amazonaws.com/object?acl") + .body(Body::empty()) + .unwrap(); + req.headers_mut() + .insert("host", "examplebucket.s3.amazonaws.com".parse().unwrap()); + let mut buf = BytesMut::new(); + write_canonicalized_resource(&mut buf, &req, false); + assert_eq!(String::from_utf8(buf.to_vec()).unwrap(), "/object?acl"); + } + + #[test] + fn test_sign_v2_signature_matches_injected_date() { + let mut req = request::Request::builder() + .method(http::Method::GET) + .uri("http://examplebucket.s3.amazonaws.com/object") + .body(Body::empty()) + .unwrap(); + req.headers_mut() + .insert("host", "examplebucket.s3.amazonaws.com".parse().unwrap()); + + let req = sign_v2(req, 0, "AKIAEXAMPLE", "SECRET", false); + let expected_string_to_sign = string_to_sign_v2(&req, false); + let expected_signature = base64_simd::URL_SAFE_NO_PAD.encode_to_string(hmac_sha1("SECRET", expected_string_to_sign)); + + assert_eq!( + req.headers().get("Authorization").unwrap().to_str().unwrap(), + format!("AWS AKIAEXAMPLE:{expected_signature}") + ); } }