mirror of
https://github.com/rustfs/rustfs.git
synced 2026-01-17 01:30:33 +00:00
Fix CRC32C Checksum Implementation and Enhance Authentication System (#678)
* fix: get_condition_values * fix checksum crc32c * fix clippy
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -6773,6 +6773,7 @@ dependencies = [
|
||||
"base64 0.22.1",
|
||||
"base64-simd",
|
||||
"bytes",
|
||||
"crc32c",
|
||||
"crc32fast",
|
||||
"crc64fast-nvme",
|
||||
"futures",
|
||||
@@ -6888,6 +6889,7 @@ dependencies = [
|
||||
"hickory-resolver",
|
||||
"highway",
|
||||
"hmac 0.12.1",
|
||||
"http 1.3.1",
|
||||
"hyper 1.7.0",
|
||||
"libc",
|
||||
"local-ip-address",
|
||||
|
||||
@@ -121,6 +121,7 @@ chrono = { version = "0.4.42", features = ["serde"] }
|
||||
clap = { version = "4.5.49", features = ["derive", "env"] }
|
||||
const-str = { version = "0.7.0", features = ["std", "proc"] }
|
||||
crc32fast = "1.5.0"
|
||||
crc32c = "0.6.8"
|
||||
crc64fast-nvme = "1.2.0"
|
||||
criterion = { version = "0.7", features = ["html_reports"] }
|
||||
crossbeam-queue = "0.3.12"
|
||||
|
||||
@@ -30,7 +30,7 @@ workspace = true
|
||||
|
||||
[dependencies]
|
||||
rustfs-config = { workspace = true, features = ["constants","opa"] }
|
||||
tokio.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
time = { workspace = true, features = ["serde-human-readable"] }
|
||||
serde = { workspace = true, features = ["derive", "rc"] }
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -39,7 +39,7 @@ pub struct AuthZPlugin {
|
||||
fn check() -> Result<(), String> {
|
||||
let env_list = env::vars();
|
||||
let mut candidate = HashMap::new();
|
||||
let prefix = format!("{}{}", ENV_PREFIX, POLICY_PLUGIN_SUB_SYS).to_uppercase();
|
||||
let prefix = format!("{ENV_PREFIX}{POLICY_PLUGIN_SUB_SYS}").to_uppercase();
|
||||
for (key, value) in env_list {
|
||||
if key.starts_with(&prefix) {
|
||||
candidate.insert(key.to_string(), value);
|
||||
@@ -48,13 +48,13 @@ fn check() -> Result<(), String> {
|
||||
|
||||
//check required env vars
|
||||
if candidate.remove(ENV_POLICY_PLUGIN_OPA_URL).is_none() {
|
||||
return Err(format!("Missing required env var: {}", ENV_POLICY_PLUGIN_OPA_URL));
|
||||
return Err(format!("Missing required env var: {ENV_POLICY_PLUGIN_OPA_URL}"));
|
||||
}
|
||||
|
||||
// check optional env vars
|
||||
candidate.remove(ENV_POLICY_PLUGIN_AUTH_TOKEN);
|
||||
if !candidate.is_empty() {
|
||||
return Err(format!("Invalid env vars: {:?}", candidate));
|
||||
return Err(format!("Invalid env vars: {candidate:?}"));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -73,7 +73,7 @@ async fn validate(config: &Args) -> Result<(), String> {
|
||||
};
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(format!("Error connecting to OPA: {}", err));
|
||||
return Err(format!("Error connecting to OPA: {err}"));
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
@@ -83,7 +83,7 @@ pub async fn lookup_config() -> Result<Args, String> {
|
||||
let args = Args::default();
|
||||
|
||||
let get_cfg =
|
||||
|cfg: &str| -> Result<String, String> { env::var(cfg).map_err(|e| format!("Error getting env var {}: {:?}", cfg, e)) };
|
||||
|cfg: &str| -> Result<String, String> { env::var(cfg).map_err(|e| format!("Error getting env var {cfg}: {e:?}")) };
|
||||
|
||||
let url = get_cfg(ENV_POLICY_PLUGIN_OPA_URL);
|
||||
if url.is_err() {
|
||||
|
||||
@@ -52,6 +52,7 @@ base64-simd.workspace = true
|
||||
crc64fast-nvme.workspace = true
|
||||
s3s.workspace = true
|
||||
hex-simd.workspace = true
|
||||
crc32c.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
|
||||
@@ -645,27 +645,18 @@ impl ChecksumHasher for Crc32IeeeHasher {
|
||||
}
|
||||
|
||||
/// CRC32 Castagnoli hasher
|
||||
pub struct Crc32CastagnoliHasher {
|
||||
hasher: crc32fast::Hasher,
|
||||
}
|
||||
|
||||
impl Default for Crc32CastagnoliHasher {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
#[derive(Default)]
|
||||
pub struct Crc32CastagnoliHasher(u32);
|
||||
|
||||
impl Crc32CastagnoliHasher {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
hasher: crc32fast::Hasher::new_with_initial(0),
|
||||
}
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for Crc32CastagnoliHasher {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.hasher.update(buf);
|
||||
self.0 = crc32c::crc32c_append(self.0, buf);
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
||||
@@ -676,11 +667,11 @@ impl Write for Crc32CastagnoliHasher {
|
||||
|
||||
impl ChecksumHasher for Crc32CastagnoliHasher {
|
||||
fn finalize(&mut self) -> Vec<u8> {
|
||||
self.hasher.clone().finalize().to_be_bytes().to_vec()
|
||||
self.0.to_be_bytes().to_vec()
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.hasher = crc32fast::Hasher::new_with_initial(0);
|
||||
self.0 = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -491,19 +491,22 @@ impl AsyncRead for HashReader {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let content_hash = hasher.finalize();
|
||||
if content_hash != expected_content_hash.raw {
|
||||
error!(
|
||||
"Content hash mismatch, expected={:?}, actual={:?}",
|
||||
hex_simd::encode_to_string(&expected_content_hash.raw, hex_simd::AsciiCase::Lower),
|
||||
hex_simd::encode_to_string(content_hash, hex_simd::AsciiCase::Lower)
|
||||
);
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Content hash mismatch",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let content_hash = hasher.finalize();
|
||||
|
||||
if content_hash != expected_content_hash.raw {
|
||||
error!(
|
||||
"Content hash mismatch, type={:?}, encoded={:?}, expected={:?}, actual={:?}",
|
||||
expected_content_hash.checksum_type,
|
||||
expected_content_hash.encoded,
|
||||
hex_simd::encode_to_string(&expected_content_hash.raw, hex_simd::AsciiCase::Lower),
|
||||
hex_simd::encode_to_string(content_hash, hex_simd::AsciiCase::Lower)
|
||||
);
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Content hash mismatch",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ hex-simd = { workspace = true, optional = true }
|
||||
highway = { workspace = true, optional = true }
|
||||
hickory-resolver = { workspace = true, optional = true }
|
||||
hmac = { workspace = true, optional = true }
|
||||
http = { workspace = true, optional = true }
|
||||
hyper = { workspace = true, optional = true }
|
||||
libc = { workspace = true, optional = true }
|
||||
local-ip-address = { workspace = true, optional = true }
|
||||
@@ -93,5 +94,5 @@ hash = ["dep:highway", "dep:md-5", "dep:sha2", "dep:blake3", "dep:serde", "dep:s
|
||||
os = ["dep:nix", "dep:tempfile", "winapi"] # operating system utilities
|
||||
integration = [] # integration test features
|
||||
sys = ["dep:sysinfo"] # system information features
|
||||
http = ["dep:convert_case"]
|
||||
http = ["dep:convert_case", "dep:http"]
|
||||
full = ["ip", "tls", "net", "io", "hash", "os", "integration", "path", "crypto", "string", "compress", "sys", "notify", "http"] # all features
|
||||
|
||||
201
crates/utils/src/http/ip.rs
Normal file
201
crates/utils/src/http/ip.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use http::HeaderMap;
|
||||
use regex::Regex;
|
||||
use std::env;
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// De-facto standard header keys.
|
||||
const X_FORWARDED_FOR: &str = "x-forwarded-for";
|
||||
const X_FORWARDED_PROTO: &str = "x-forwarded-proto";
|
||||
const X_FORWARDED_SCHEME: &str = "x-forwarded-scheme";
|
||||
const X_REAL_IP: &str = "x-real-ip";
|
||||
|
||||
/// RFC7239 defines a new "Forwarded: " header designed to replace the
|
||||
/// existing use of X-Forwarded-* headers.
|
||||
/// e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43
|
||||
const FORWARDED: &str = "forwarded";
|
||||
|
||||
static FOR_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)(?:for=)([^(;|,| )]+)(.*)").unwrap());
|
||||
static PROTO_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)^(;|,| )+(?:proto=)(https|http)").unwrap());
|
||||
|
||||
/// Used to disable all processing of the X-Forwarded-For header in source IP discovery.
|
||||
fn is_xff_header_enabled() -> bool {
|
||||
env::var("_RUSTFS_API_XFF_HEADER")
|
||||
.unwrap_or_else(|_| "on".to_string())
|
||||
.to_lowercase()
|
||||
== "on"
|
||||
}
|
||||
|
||||
/// GetSourceScheme retrieves the scheme from the X-Forwarded-Proto and RFC7239
|
||||
/// Forwarded headers (in that order).
|
||||
pub fn get_source_scheme(headers: &HeaderMap) -> Option<String> {
|
||||
// Retrieve the scheme from X-Forwarded-Proto.
|
||||
if let Some(proto) = headers.get(X_FORWARDED_PROTO) {
|
||||
if let Ok(proto_str) = proto.to_str() {
|
||||
return Some(proto_str.to_lowercase());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(proto) = headers.get(X_FORWARDED_SCHEME) {
|
||||
if let Ok(proto_str) = proto.to_str() {
|
||||
return Some(proto_str.to_lowercase());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(forwarded) = headers.get(FORWARDED) {
|
||||
if let Ok(forwarded_str) = forwarded.to_str() {
|
||||
// match should contain at least two elements if the protocol was
|
||||
// specified in the Forwarded header. The first element will always be
|
||||
// the 'for=', which we ignore, subsequently we proceed to look for
|
||||
// 'proto=' which should precede right after `for=` if not
|
||||
// we simply ignore the values and return empty. This is in line
|
||||
// with the approach we took for returning first ip from multiple
|
||||
// params.
|
||||
if let Some(for_match) = FOR_REGEX.captures(forwarded_str) {
|
||||
if for_match.len() > 1 {
|
||||
let remaining = &for_match[2];
|
||||
if let Some(proto_match) = PROTO_REGEX.captures(remaining) {
|
||||
if proto_match.len() > 1 {
|
||||
return Some(proto_match[2].to_lowercase());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// GetSourceIPFromHeaders retrieves the IP from the X-Forwarded-For, X-Real-IP
|
||||
/// and RFC7239 Forwarded headers (in that order)
|
||||
pub fn get_source_ip_from_headers(headers: &HeaderMap) -> Option<String> {
|
||||
let mut addr = None;
|
||||
|
||||
if is_xff_header_enabled() {
|
||||
if let Some(forwarded_for) = headers.get(X_FORWARDED_FOR) {
|
||||
if let Ok(forwarded_str) = forwarded_for.to_str() {
|
||||
// Only grab the first (client) address. Note that '192.168.0.1,
|
||||
// 10.1.1.1' is a valid key for X-Forwarded-For where addresses after
|
||||
// the first may represent forwarding proxies earlier in the chain.
|
||||
let first_comma = forwarded_str.find(", ");
|
||||
let end = first_comma.unwrap_or(forwarded_str.len());
|
||||
addr = Some(forwarded_str[..end].to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addr.is_none() {
|
||||
if let Some(real_ip) = headers.get(X_REAL_IP) {
|
||||
if let Ok(real_ip_str) = real_ip.to_str() {
|
||||
// X-Real-IP should only contain one IP address (the client making the
|
||||
// request).
|
||||
addr = Some(real_ip_str.to_string());
|
||||
}
|
||||
} else if let Some(forwarded) = headers.get(FORWARDED) {
|
||||
if let Ok(forwarded_str) = forwarded.to_str() {
|
||||
// match should contain at least two elements if the protocol was
|
||||
// specified in the Forwarded header. The first element will always be
|
||||
// the 'for=' capture, which we ignore. In the case of multiple IP
|
||||
// addresses (for=8.8.8.8, 8.8.4.4, 172.16.1.20 is valid) we only
|
||||
// extract the first, which should be the client IP.
|
||||
if let Some(for_match) = FOR_REGEX.captures(forwarded_str) {
|
||||
if for_match.len() > 1 {
|
||||
// IPv6 addresses in Forwarded headers are quoted-strings. We strip
|
||||
// these quotes.
|
||||
let ip = for_match[1].trim_matches('"');
|
||||
addr = Some(ip.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
addr
|
||||
}
|
||||
|
||||
/// GetSourceIPRaw retrieves the IP from the request headers
|
||||
/// and falls back to remote_addr when necessary.
|
||||
/// however returns without bracketing.
|
||||
pub fn get_source_ip_raw(headers: &HeaderMap, remote_addr: &str) -> String {
|
||||
let addr = get_source_ip_from_headers(headers).unwrap_or_else(|| remote_addr.to_string());
|
||||
|
||||
// Default to remote address if headers not set.
|
||||
if let Ok(socket_addr) = SocketAddr::from_str(&addr) {
|
||||
socket_addr.ip().to_string()
|
||||
} else {
|
||||
addr
|
||||
}
|
||||
}
|
||||
|
||||
/// GetSourceIP retrieves the IP from the request headers
|
||||
/// and falls back to remote_addr when necessary.
|
||||
pub fn get_source_ip(headers: &HeaderMap, remote_addr: &str) -> String {
|
||||
let addr = get_source_ip_raw(headers, remote_addr);
|
||||
if addr.contains(':') { format!("[{addr}]") } else { addr }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use http::HeaderValue;
|
||||
|
||||
fn create_test_headers() -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-forwarded-for", HeaderValue::from_static("192.168.1.1"));
|
||||
headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));
|
||||
headers
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_source_scheme() {
|
||||
let headers = create_test_headers();
|
||||
assert_eq!(get_source_scheme(&headers), Some("https".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_source_ip_from_headers() {
|
||||
let headers = create_test_headers();
|
||||
assert_eq!(get_source_ip_from_headers(&headers), Some("192.168.1.1".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_source_ip_raw() {
|
||||
let headers = create_test_headers();
|
||||
let remote_addr = "127.0.0.1:8080";
|
||||
let result = get_source_ip_raw(&headers, remote_addr);
|
||||
assert_eq!(result, "192.168.1.1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_source_ip() {
|
||||
let headers = create_test_headers();
|
||||
let remote_addr = "127.0.0.1:8080";
|
||||
let result = get_source_ip(&headers, remote_addr);
|
||||
assert_eq!(result, "192.168.1.1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_source_ip_ipv6() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-forwarded-for", HeaderValue::from_static("2001:db8::1"));
|
||||
let remote_addr = "127.0.0.1:8080";
|
||||
let result = get_source_ip(&headers, remote_addr);
|
||||
assert_eq!(result, "[2001:db8::1]");
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,18 @@
|
||||
pub mod headers;
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod headers;
|
||||
pub mod ip;
|
||||
pub use headers::*;
|
||||
pub use ip::*;
|
||||
|
||||
@@ -42,7 +42,7 @@ async fn check_admin_request_auth(
|
||||
deny_only: bool,
|
||||
action: Action,
|
||||
) -> S3Result<()> {
|
||||
let conditions = get_condition_values(headers, cred);
|
||||
let conditions = get_condition_values(headers, cred, None, None);
|
||||
|
||||
if !iam_store
|
||||
.is_allowed(&Args {
|
||||
|
||||
@@ -144,7 +144,7 @@ impl Operation for AccountInfoHandler {
|
||||
let claims = cred.claims.as_ref().unwrap_or(&default_claims);
|
||||
|
||||
let cred_clone = cred.clone();
|
||||
let conditions = get_condition_values(&req.headers, &cred_clone);
|
||||
let conditions = get_condition_values(&req.headers, &cred_clone, None, None);
|
||||
let cred_clone = Arc::new(cred_clone);
|
||||
let conditions = Arc::new(conditions);
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ impl Operation for AddServiceAccount {
|
||||
groups: &cred.groups,
|
||||
action: Action::AdminAction(AdminAction::CreateServiceAccountAdminAction),
|
||||
bucket: "",
|
||||
conditions: &get_condition_values(&req.headers, &cred),
|
||||
conditions: &get_condition_values(&req.headers, &cred, None, None),
|
||||
is_owner: owner,
|
||||
object: "",
|
||||
claims: cred.claims.as_ref().unwrap_or(&HashMap::new()),
|
||||
@@ -263,7 +263,7 @@ impl Operation for UpdateServiceAccount {
|
||||
groups: &cred.groups,
|
||||
action: Action::AdminAction(AdminAction::UpdateServiceAccountAdminAction),
|
||||
bucket: "",
|
||||
conditions: &get_condition_values(&req.headers, &cred),
|
||||
conditions: &get_condition_values(&req.headers, &cred, None, None),
|
||||
is_owner: owner,
|
||||
object: "",
|
||||
claims: cred.claims.as_ref().unwrap_or(&HashMap::new()),
|
||||
@@ -356,7 +356,7 @@ impl Operation for InfoServiceAccount {
|
||||
groups: &cred.groups,
|
||||
action: Action::AdminAction(AdminAction::ListServiceAccountsAdminAction),
|
||||
bucket: "",
|
||||
conditions: &get_condition_values(&req.headers, &cred),
|
||||
conditions: &get_condition_values(&req.headers, &cred, None, None),
|
||||
is_owner: owner,
|
||||
object: "",
|
||||
claims: cred.claims.as_ref().unwrap_or(&HashMap::new()),
|
||||
@@ -484,7 +484,7 @@ impl Operation for ListServiceAccount {
|
||||
groups: &cred.groups,
|
||||
action: Action::AdminAction(AdminAction::UpdateServiceAccountAdminAction),
|
||||
bucket: "",
|
||||
conditions: &get_condition_values(&req.headers, &cred),
|
||||
conditions: &get_condition_values(&req.headers, &cred, None, None),
|
||||
is_owner: owner,
|
||||
object: "",
|
||||
claims: cred.claims.as_ref().unwrap_or(&HashMap::new()),
|
||||
@@ -582,7 +582,7 @@ impl Operation for DeleteServiceAccount {
|
||||
groups: &cred.groups,
|
||||
action: Action::AdminAction(AdminAction::RemoveServiceAccountAdminAction),
|
||||
bucket: "",
|
||||
conditions: &get_condition_values(&req.headers, &cred),
|
||||
conditions: &get_condition_values(&req.headers, &cred, None, None),
|
||||
is_owner: owner,
|
||||
object: "",
|
||||
claims: cred.claims.as_ref().unwrap_or(&HashMap::new()),
|
||||
|
||||
@@ -19,6 +19,7 @@ use rustfs_iam::error::Error as IamError;
|
||||
use rustfs_iam::sys::SESSION_POLICY_NAME;
|
||||
use rustfs_iam::sys::get_claims_from_token_with_secret;
|
||||
use rustfs_policy::auth;
|
||||
use rustfs_utils::http::ip::get_source_ip_raw;
|
||||
use s3s::S3Error;
|
||||
use s3s::S3ErrorCode;
|
||||
use s3s::S3Result;
|
||||
@@ -28,6 +29,40 @@ use s3s::auth::SimpleAuth;
|
||||
use s3s::s3_error;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::well_known::Rfc3339;
|
||||
|
||||
// Authentication type constants
|
||||
const JWT_ALGORITHM: &str = "Bearer ";
|
||||
const SIGN_V2_ALGORITHM: &str = "AWS ";
|
||||
const SIGN_V4_ALGORITHM: &str = "AWS4-HMAC-SHA256";
|
||||
const STREAMING_CONTENT_SHA256: &str = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD";
|
||||
const STREAMING_CONTENT_SHA256_TRAILER: &str = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER";
|
||||
pub const UNSIGNED_PAYLOAD_TRAILER: &str = "STREAMING-UNSIGNED-PAYLOAD-TRAILER";
|
||||
const ACTION_HEADER: &str = "Action";
|
||||
const AMZ_CREDENTIAL: &str = "X-Amz-Credential";
|
||||
const AMZ_ACCESS_KEY_ID: &str = "AWSAccessKeyId";
|
||||
pub const UNSIGNED_PAYLOAD: &str = "UNSIGNED-PAYLOAD";
|
||||
|
||||
// Authentication type enum
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub enum AuthType {
|
||||
#[default]
|
||||
Unknown,
|
||||
Anonymous,
|
||||
Presigned,
|
||||
PresignedV2,
|
||||
PostPolicy,
|
||||
StreamingSigned,
|
||||
Signed,
|
||||
SignedV2,
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
JWT,
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
STS,
|
||||
StreamingSignedTrailer,
|
||||
StreamingUnsignedTrailer,
|
||||
}
|
||||
|
||||
pub struct IAMAuth {
|
||||
simple_auth: SimpleAuth,
|
||||
@@ -167,7 +202,12 @@ pub fn get_session_token<'a>(uri: &'a Uri, hds: &'a HeaderMap) -> Option<&'a str
|
||||
.or_else(|| get_query_param(uri.query().unwrap_or_default(), "x-amz-security-token"))
|
||||
}
|
||||
|
||||
pub fn get_condition_values(header: &HeaderMap, cred: &auth::Credentials) -> HashMap<String, Vec<String>> {
|
||||
pub fn get_condition_values(
|
||||
header: &HeaderMap,
|
||||
cred: &auth::Credentials,
|
||||
version_id: Option<&str>,
|
||||
region: Option<&str>,
|
||||
) -> HashMap<String, Vec<String>> {
|
||||
let username = if cred.is_temp() || cred.is_service_account() {
|
||||
cred.parent_user.clone()
|
||||
} else {
|
||||
@@ -190,32 +230,83 @@ pub fn get_condition_values(header: &HeaderMap, cred: &auth::Credentials) -> Has
|
||||
"Anonymous"
|
||||
};
|
||||
|
||||
// Get current time
|
||||
let curr_time = OffsetDateTime::now_utc();
|
||||
let epoch_time = curr_time.unix_timestamp();
|
||||
|
||||
// Use provided version ID or empty string
|
||||
let vid = version_id.unwrap_or("");
|
||||
|
||||
// Determine auth type and signature version from headers
|
||||
let (auth_type, signature_version) = determine_auth_type_and_version(header);
|
||||
|
||||
// Get TLS status from header
|
||||
let is_tls = header
|
||||
.get("x-forwarded-proto")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s == "https")
|
||||
.or_else(|| {
|
||||
header
|
||||
.get("x-forwarded-scheme")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s == "https")
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
// Get remote address from header or use default
|
||||
let remote_addr = header
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.split(',').next())
|
||||
.or_else(|| header.get("x-real-ip").and_then(|v| v.to_str().ok()))
|
||||
.unwrap_or("127.0.0.1");
|
||||
|
||||
let mut args = HashMap::new();
|
||||
|
||||
// Add basic time and security info
|
||||
args.insert("CurrentTime".to_owned(), vec![curr_time.format(&Rfc3339).unwrap_or_default()]);
|
||||
args.insert("EpochTime".to_owned(), vec![epoch_time.to_string()]);
|
||||
args.insert("SecureTransport".to_owned(), vec![is_tls.to_string()]);
|
||||
args.insert("SourceIp".to_owned(), vec![get_source_ip_raw(header, remote_addr)]);
|
||||
|
||||
// Add user agent and referer
|
||||
if let Some(user_agent) = header.get("user-agent") {
|
||||
args.insert("UserAgent".to_owned(), vec![user_agent.to_str().unwrap_or("").to_string()]);
|
||||
}
|
||||
if let Some(referer) = header.get("referer") {
|
||||
args.insert("Referer".to_owned(), vec![referer.to_str().unwrap_or("").to_string()]);
|
||||
}
|
||||
|
||||
// Add user and principal info
|
||||
args.insert("userid".to_owned(), vec![username.clone()]);
|
||||
args.insert("username".to_owned(), vec![username]);
|
||||
args.insert("principaltype".to_owned(), vec![principal_type.to_string()]);
|
||||
|
||||
// Add version ID
|
||||
if !vid.is_empty() {
|
||||
args.insert("versionid".to_owned(), vec![vid.to_string()]);
|
||||
}
|
||||
|
||||
// Add signature version and auth type
|
||||
if !signature_version.is_empty() {
|
||||
args.insert("signatureversion".to_owned(), vec![signature_version]);
|
||||
}
|
||||
if !auth_type.is_empty() {
|
||||
args.insert("authType".to_owned(), vec![auth_type]);
|
||||
}
|
||||
|
||||
if let Some(lc) = region {
|
||||
if !lc.is_empty() {
|
||||
args.insert("LocationConstraint".to_owned(), vec![lc.to_string()]);
|
||||
}
|
||||
}
|
||||
|
||||
let mut clone_header = header.clone();
|
||||
if let Some(v) = clone_header.get("x-amz-signature-age") {
|
||||
args.insert("signatureAge".to_string(), vec![v.to_str().unwrap_or("").to_string()]);
|
||||
clone_header.remove("x-amz-signature-age");
|
||||
}
|
||||
|
||||
// TODO: parse_object_tags
|
||||
// if let Some(_user_tags) = clone_header.get("x-amz-tagging") {
|
||||
// TODO: parse_object_tags
|
||||
// if let Ok(tag) = tags::parse_object_tags(user_tags.to_str().unwrap_or("")) {
|
||||
// let tag_map = tag.to_map();
|
||||
// let mut keys = Vec::new();
|
||||
// for (k, v) in tag_map {
|
||||
// args.insert(format!("ExistingObjectTag/{}", k), vec![v.clone()]);
|
||||
// args.insert(format!("RequestObjectTag/{}", k), vec![v.clone()]);
|
||||
// keys.push(k);
|
||||
// }
|
||||
// args.insert("RequestObjectTagKeys".to_string(), keys);
|
||||
// }
|
||||
// }
|
||||
|
||||
for obj_lock in &[
|
||||
"x-amz-object-lock-mode",
|
||||
"x-amz-object-lock-legal-hold",
|
||||
@@ -250,37 +341,6 @@ pub fn get_condition_values(header: &HeaderMap, cred: &auth::Credentials) -> Has
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: add from url query
|
||||
// let mut clone_url_values = r
|
||||
// .uri()
|
||||
// .query()
|
||||
// .unwrap_or("")
|
||||
// .split('&')
|
||||
// .map(|s| {
|
||||
// let mut split = s.split('=');
|
||||
// (split.next().unwrap_or("").to_string(), split.next().unwrap_or("").to_string())
|
||||
// })
|
||||
// .collect::<HashMap<String, String>>();
|
||||
|
||||
// for obj_lock in &[
|
||||
// "x-amz-object-lock-mode",
|
||||
// "x-amz-object-lock-legal-hold",
|
||||
// "x-amz-object-lock-retain-until-date",
|
||||
// ] {
|
||||
// if let Some(values) = clone_url_values.get(*obj_lock) {
|
||||
// args.insert(obj_lock.trim_start_matches("x-amz-").to_string(), vec![values.clone()]);
|
||||
// }
|
||||
// clone_url_values.remove(*obj_lock);
|
||||
// }
|
||||
|
||||
// for (key, values) in clone_url_values.iter() {
|
||||
// if let Some(existing_values) = args.get_mut(key) {
|
||||
// existing_values.push(values.clone());
|
||||
// } else {
|
||||
// args.insert(key.clone(), vec![values.clone()]);
|
||||
// }
|
||||
// }
|
||||
|
||||
if let Some(claims) = &cred.claims {
|
||||
for (k, v) in claims {
|
||||
if let Some(v_str) = v.as_str() {
|
||||
@@ -310,6 +370,152 @@ pub fn get_condition_values(header: &HeaderMap, cred: &auth::Credentials) -> Has
|
||||
args
|
||||
}
|
||||
|
||||
// Get request authentication type
|
||||
pub fn get_request_auth_type(header: &HeaderMap) -> AuthType {
|
||||
if is_request_signature_v2(header) {
|
||||
AuthType::SignedV2
|
||||
} else if is_request_presigned_signature_v2(header) {
|
||||
AuthType::PresignedV2
|
||||
} else if is_request_sign_streaming_v4(header) {
|
||||
AuthType::StreamingSigned
|
||||
} else if is_request_sign_streaming_trailer_v4(header) {
|
||||
AuthType::StreamingSignedTrailer
|
||||
} else if is_request_unsigned_trailer_v4(header) {
|
||||
AuthType::StreamingUnsignedTrailer
|
||||
} else if is_request_signature_v4(header) {
|
||||
AuthType::Signed
|
||||
} else if is_request_presigned_signature_v4(header) {
|
||||
AuthType::Presigned
|
||||
} else if is_request_jwt(header) {
|
||||
AuthType::JWT
|
||||
} else if is_request_post_policy_signature_v4(header) {
|
||||
AuthType::PostPolicy
|
||||
} else if is_request_sts(header) {
|
||||
AuthType::STS
|
||||
} else if is_request_anonymous(header) {
|
||||
AuthType::Anonymous
|
||||
} else {
|
||||
AuthType::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to determine auth type and signature version
|
||||
fn determine_auth_type_and_version(header: &HeaderMap) -> (String, String) {
|
||||
match get_request_auth_type(header) {
|
||||
AuthType::JWT => ("JWT".to_string(), String::new()),
|
||||
AuthType::SignedV2 => ("REST-HEADER".to_string(), "AWS2".to_string()),
|
||||
AuthType::PresignedV2 => ("REST-QUERY-STRING".to_string(), "AWS2".to_string()),
|
||||
AuthType::StreamingSigned | AuthType::StreamingSignedTrailer | AuthType::StreamingUnsignedTrailer => {
|
||||
("REST-HEADER".to_string(), "AWS4-HMAC-SHA256".to_string())
|
||||
}
|
||||
AuthType::Signed => ("REST-HEADER".to_string(), "AWS4-HMAC-SHA256".to_string()),
|
||||
AuthType::Presigned => ("REST-QUERY-STRING".to_string(), "AWS4-HMAC-SHA256".to_string()),
|
||||
AuthType::PostPolicy => ("POST".to_string(), String::new()),
|
||||
AuthType::STS => ("STS".to_string(), String::new()),
|
||||
AuthType::Anonymous => ("Anonymous".to_string(), String::new()),
|
||||
AuthType::Unknown => (String::new(), String::new()),
|
||||
}
|
||||
}
|
||||
|
||||
// Verify if request has JWT
|
||||
fn is_request_jwt(header: &HeaderMap) -> bool {
|
||||
if let Some(auth) = header.get("authorization") {
|
||||
if let Ok(auth_str) = auth.to_str() {
|
||||
return auth_str.starts_with(JWT_ALGORITHM);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if request has AWS Signature Version '4'
|
||||
fn is_request_signature_v4(header: &HeaderMap) -> bool {
|
||||
if let Some(auth) = header.get("authorization") {
|
||||
if let Ok(auth_str) = auth.to_str() {
|
||||
return auth_str.starts_with(SIGN_V4_ALGORITHM);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if request has AWS Signature Version '2'
|
||||
fn is_request_signature_v2(header: &HeaderMap) -> bool {
|
||||
if let Some(auth) = header.get("authorization") {
|
||||
if let Ok(auth_str) = auth.to_str() {
|
||||
return !auth_str.starts_with(SIGN_V4_ALGORITHM) && auth_str.starts_with(SIGN_V2_ALGORITHM);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if request has AWS PreSign Version '4'
|
||||
pub(crate) fn is_request_presigned_signature_v4(header: &HeaderMap) -> bool {
|
||||
if let Some(credential) = header.get(AMZ_CREDENTIAL) {
|
||||
return !credential.to_str().unwrap_or("").is_empty();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify request has AWS PreSign Version '2'
|
||||
fn is_request_presigned_signature_v2(header: &HeaderMap) -> bool {
|
||||
if let Some(access_key) = header.get(AMZ_ACCESS_KEY_ID) {
|
||||
return !access_key.to_str().unwrap_or("").is_empty();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if request has AWS Post policy Signature Version '4'
|
||||
fn is_request_post_policy_signature_v4(header: &HeaderMap) -> bool {
|
||||
if let Some(content_type) = header.get("content-type") {
|
||||
if let Ok(ct) = content_type.to_str() {
|
||||
return ct.contains("multipart/form-data");
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if the request has AWS Streaming Signature Version '4'
|
||||
fn is_request_sign_streaming_v4(header: &HeaderMap) -> bool {
|
||||
if let Some(content_sha256) = header.get("x-amz-content-sha256") {
|
||||
if let Ok(sha256_str) = content_sha256.to_str() {
|
||||
return sha256_str == STREAMING_CONTENT_SHA256;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if the request has AWS Streaming Signature Version '4' with trailer
|
||||
fn is_request_sign_streaming_trailer_v4(header: &HeaderMap) -> bool {
|
||||
if let Some(content_sha256) = header.get("x-amz-content-sha256") {
|
||||
if let Ok(sha256_str) = content_sha256.to_str() {
|
||||
return sha256_str == STREAMING_CONTENT_SHA256_TRAILER;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if the request has AWS Streaming Signature Version '4' with unsigned content and trailer
|
||||
fn is_request_unsigned_trailer_v4(header: &HeaderMap) -> bool {
|
||||
if let Some(content_sha256) = header.get("x-amz-content-sha256") {
|
||||
if let Ok(sha256_str) = content_sha256.to_str() {
|
||||
return sha256_str == UNSIGNED_PAYLOAD_TRAILER;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if request is STS (Security Token Service)
|
||||
fn is_request_sts(header: &HeaderMap) -> bool {
|
||||
if let Some(action) = header.get(ACTION_HEADER) {
|
||||
return !action.to_str().unwrap_or("").is_empty();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Verify if request is anonymous
|
||||
fn is_request_anonymous(header: &HeaderMap) -> bool {
|
||||
header.get("authorization").is_none()
|
||||
}
|
||||
|
||||
pub fn get_query_param<'a>(query: &'a str, param_name: &str) -> Option<&'a str> {
|
||||
let param_name = param_name.to_lowercase();
|
||||
|
||||
@@ -549,7 +755,7 @@ mod tests {
|
||||
let cred = create_test_credentials();
|
||||
let headers = HeaderMap::new();
|
||||
|
||||
let conditions = get_condition_values(&headers, &cred);
|
||||
let conditions = get_condition_values(&headers, &cred, None, None);
|
||||
|
||||
assert_eq!(conditions.get("userid"), Some(&vec!["test-access-key".to_string()]));
|
||||
assert_eq!(conditions.get("username"), Some(&vec!["test-access-key".to_string()]));
|
||||
@@ -561,7 +767,7 @@ mod tests {
|
||||
let cred = create_temp_credentials();
|
||||
let headers = HeaderMap::new();
|
||||
|
||||
let conditions = get_condition_values(&headers, &cred);
|
||||
let conditions = get_condition_values(&headers, &cred, None, None);
|
||||
|
||||
assert_eq!(conditions.get("userid"), Some(&vec!["parent-user".to_string()]));
|
||||
assert_eq!(conditions.get("username"), Some(&vec!["parent-user".to_string()]));
|
||||
@@ -573,7 +779,7 @@ mod tests {
|
||||
let cred = create_service_account_credentials();
|
||||
let headers = HeaderMap::new();
|
||||
|
||||
let conditions = get_condition_values(&headers, &cred);
|
||||
let conditions = get_condition_values(&headers, &cred, None, None);
|
||||
|
||||
assert_eq!(conditions.get("userid"), Some(&vec!["service-parent".to_string()]));
|
||||
assert_eq!(conditions.get("username"), Some(&vec!["service-parent".to_string()]));
|
||||
@@ -588,7 +794,7 @@ mod tests {
|
||||
headers.insert("x-amz-object-lock-mode", HeaderValue::from_static("GOVERNANCE"));
|
||||
headers.insert("x-amz-object-lock-retain-until-date", HeaderValue::from_static("2024-12-31T23:59:59Z"));
|
||||
|
||||
let conditions = get_condition_values(&headers, &cred);
|
||||
let conditions = get_condition_values(&headers, &cred, None, None);
|
||||
|
||||
assert_eq!(conditions.get("object-lock-mode"), Some(&vec!["GOVERNANCE".to_string()]));
|
||||
assert_eq!(
|
||||
@@ -603,7 +809,7 @@ mod tests {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-amz-signature-age", HeaderValue::from_static("300"));
|
||||
|
||||
let conditions = get_condition_values(&headers, &cred);
|
||||
let conditions = get_condition_values(&headers, &cred, None, None);
|
||||
|
||||
assert_eq!(conditions.get("signatureAge"), Some(&vec!["300".to_string()]));
|
||||
// Verify the header is removed after processing
|
||||
@@ -620,7 +826,7 @@ mod tests {
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
|
||||
let conditions = get_condition_values(&headers, &cred);
|
||||
let conditions = get_condition_values(&headers, &cred, None, None);
|
||||
|
||||
assert_eq!(conditions.get("username"), Some(&vec!["ldap-user".to_string()]));
|
||||
assert_eq!(conditions.get("groups"), Some(&vec!["group1".to_string(), "group2".to_string()]));
|
||||
@@ -633,7 +839,7 @@ mod tests {
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
|
||||
let conditions = get_condition_values(&headers, &cred);
|
||||
let conditions = get_condition_values(&headers, &cred, None, None);
|
||||
|
||||
assert_eq!(
|
||||
conditions.get("groups"),
|
||||
@@ -758,4 +964,138 @@ mod tests {
|
||||
|
||||
assert!(!cred.is_service_account());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_jwt() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("authorization", HeaderValue::from_static("Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"));
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::JWT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_signature_v2() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"authorization",
|
||||
HeaderValue::from_static("AWS AKIAIOSFODNN7EXAMPLE:frJIUN8DYpKDtOLCwo//bqJZQ1iY="),
|
||||
);
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::SignedV2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_signature_v4() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"authorization",
|
||||
HeaderValue::from_static("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request"),
|
||||
);
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::Signed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_presigned_v2() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("AWSAccessKeyId", HeaderValue::from_static("AKIAIOSFODNN7EXAMPLE"));
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::PresignedV2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_presigned_v4() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"X-Amz-Credential",
|
||||
HeaderValue::from_static("AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request"),
|
||||
);
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::Presigned);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_post_policy() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"content-type",
|
||||
HeaderValue::from_static("multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"),
|
||||
);
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::PostPolicy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_streaming_signed() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-amz-content-sha256", HeaderValue::from_static("STREAMING-AWS4-HMAC-SHA256-PAYLOAD"));
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::StreamingSigned);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_streaming_signed_trailer() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"x-amz-content-sha256",
|
||||
HeaderValue::from_static("STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER"),
|
||||
);
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::StreamingSignedTrailer);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_streaming_unsigned_trailer() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-amz-content-sha256", HeaderValue::from_static("STREAMING-UNSIGNED-PAYLOAD-TRAILER"));
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::StreamingUnsignedTrailer);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_sts() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Action", HeaderValue::from_static("AssumeRole"));
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::STS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_anonymous() {
|
||||
let headers = HeaderMap::new();
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::Anonymous);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_request_auth_type_unknown() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("authorization", HeaderValue::from_static("CustomAuth token123"));
|
||||
|
||||
let auth_type = get_request_auth_type(&headers);
|
||||
|
||||
assert_eq!(auth_type, AuthType::Unknown);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ pub(crate) struct ReqInfo {
|
||||
pub bucket: Option<String>,
|
||||
pub object: Option<String>,
|
||||
pub version_id: Option<String>,
|
||||
pub region: Option<String>,
|
||||
}
|
||||
|
||||
/// Authorizes the request based on the action and credentials.
|
||||
@@ -48,7 +49,7 @@ pub async fn authorize_request<T>(req: &mut S3Request<T>, action: Action) -> S3R
|
||||
|
||||
let default_claims = HashMap::new();
|
||||
let claims = cred.claims.as_ref().unwrap_or(&default_claims);
|
||||
let conditions = get_condition_values(&req.headers, cred);
|
||||
let conditions = get_condition_values(&req.headers, cred, req_info.version_id.as_deref(), None);
|
||||
|
||||
if action == Action::S3Action(S3Action::DeleteObjectAction)
|
||||
&& req_info.version_id.is_some()
|
||||
@@ -104,7 +105,12 @@ pub async fn authorize_request<T>(req: &mut S3Request<T>, action: Action) -> S3R
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
let conditions = get_condition_values(&req.headers, &auth::Credentials::default());
|
||||
let conditions = get_condition_values(
|
||||
&req.headers,
|
||||
&auth::Credentials::default(),
|
||||
req_info.version_id.as_deref(),
|
||||
req.region.as_deref(),
|
||||
);
|
||||
|
||||
if action != Action::S3Action(S3Action::ListAllMyBucketsAction) {
|
||||
if PolicySys::is_allowed(&BucketPolicyArgs {
|
||||
@@ -181,6 +187,7 @@ impl S3Access for FS {
|
||||
let req_info = ReqInfo {
|
||||
cred,
|
||||
is_owner,
|
||||
region: rustfs_ecstore::global::get_global_region(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use crate::auth::get_condition_values;
|
||||
use crate::error::ApiError;
|
||||
use crate::storage::options::get_content_sha256;
|
||||
use crate::storage::{
|
||||
access::{ReqInfo, authorize_request},
|
||||
options::{
|
||||
@@ -97,8 +98,7 @@ use rustfs_targets::{
|
||||
arn::{TargetID, TargetIDError},
|
||||
};
|
||||
use rustfs_utils::http::{
|
||||
AMZ_CHECKSUM_MODE, AMZ_CHECKSUM_TYPE, AMZ_CONTENT_SHA256, AMZ_META_UNENCRYPTED_CONTENT_LENGTH,
|
||||
AMZ_META_UNENCRYPTED_CONTENT_MD5,
|
||||
AMZ_CHECKSUM_MODE, AMZ_CHECKSUM_TYPE, AMZ_META_UNENCRYPTED_CONTENT_LENGTH, AMZ_META_UNENCRYPTED_CONTENT_MD5,
|
||||
};
|
||||
use rustfs_utils::{
|
||||
CompressionAlgorithm,
|
||||
@@ -393,13 +393,7 @@ impl FS {
|
||||
None
|
||||
};
|
||||
|
||||
let sha256hex = req.headers.get(AMZ_CONTENT_SHA256).and_then(|v| {
|
||||
v.to_str()
|
||||
.ok()
|
||||
.filter(|&v| v != "UNSIGNED-PAYLOAD" && v != "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
|
||||
.map(|v| v.to_string())
|
||||
});
|
||||
|
||||
let sha256hex = get_content_sha256(&req.headers);
|
||||
let actual_size = size;
|
||||
|
||||
let reader: Box<dyn Reader> = Box::new(WarpReader::new(body));
|
||||
@@ -2383,12 +2377,7 @@ impl S3 for FS {
|
||||
None
|
||||
};
|
||||
|
||||
let mut sha256hex = req.headers.get(AMZ_CONTENT_SHA256).and_then(|v| {
|
||||
v.to_str()
|
||||
.ok()
|
||||
.filter(|&v| v != "UNSIGNED-PAYLOAD" && v != "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
|
||||
.map(|v| v.to_string())
|
||||
});
|
||||
let mut sha256hex = get_content_sha256(&req.headers);
|
||||
|
||||
if is_compressible(&req.headers, &key) && size > MIN_COMPRESSIBLE_SIZE as i64 {
|
||||
metadata.insert(
|
||||
@@ -2918,12 +2907,7 @@ impl S3 for FS {
|
||||
None
|
||||
};
|
||||
|
||||
let mut sha256hex = req.headers.get(AMZ_CONTENT_SHA256).and_then(|v| {
|
||||
v.to_str()
|
||||
.ok()
|
||||
.filter(|&v| v != "UNSIGNED-PAYLOAD" && v != "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
|
||||
.map(|v| v.to_string())
|
||||
});
|
||||
let mut sha256hex = get_content_sha256(&req.headers);
|
||||
|
||||
if is_compressible {
|
||||
let mut hrd = HashReader::new(reader, size, actual_size, md5hex, sha256hex, false).map_err(ApiError::from)?;
|
||||
@@ -3776,7 +3760,7 @@ impl S3 for FS {
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
|
||||
let conditions = get_condition_values(&req.headers, &auth::Credentials::default());
|
||||
let conditions = get_condition_values(&req.headers, &auth::Credentials::default(), None, None);
|
||||
|
||||
let read_only = PolicySys::is_allowed(&BucketPolicyArgs {
|
||||
bucket: &bucket,
|
||||
|
||||
@@ -17,7 +17,12 @@ use rustfs_ecstore::bucket::versioning_sys::BucketVersioningSys;
|
||||
use rustfs_ecstore::error::Result;
|
||||
use rustfs_ecstore::error::StorageError;
|
||||
|
||||
use crate::auth::UNSIGNED_PAYLOAD;
|
||||
use crate::auth::UNSIGNED_PAYLOAD_TRAILER;
|
||||
use rustfs_ecstore::store_api::{HTTPPreconditions, HTTPRangeSpec, ObjectOptions};
|
||||
use rustfs_policy::service_type::ServiceType;
|
||||
use rustfs_utils::hash::EMPTY_STRING_SHA256_HASH;
|
||||
use rustfs_utils::http::AMZ_CONTENT_SHA256;
|
||||
use rustfs_utils::http::RESERVED_METADATA_PREFIX_LOWER;
|
||||
use rustfs_utils::http::RUSTFS_BUCKET_REPLICATION_DELETE_MARKER;
|
||||
use rustfs_utils::http::RUSTFS_BUCKET_REPLICATION_REQUEST;
|
||||
@@ -30,6 +35,10 @@ use std::sync::LazyLock;
|
||||
use tracing::error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::AuthType;
|
||||
use crate::auth::get_request_auth_type;
|
||||
use crate::auth::is_request_presigned_signature_v4;
|
||||
|
||||
/// Creates options for deleting an object in a bucket.
|
||||
pub async fn del_opts(
|
||||
bucket: &str,
|
||||
@@ -414,6 +423,100 @@ pub fn parse_copy_source_range(range_str: &str) -> S3Result<HTTPRangeSpec> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_content_sha256(headers: &HeaderMap<HeaderValue>) -> Option<String> {
|
||||
match get_request_auth_type(headers) {
|
||||
AuthType::Presigned | AuthType::Signed => {
|
||||
if skip_content_sha256_cksum(headers) {
|
||||
None
|
||||
} else {
|
||||
Some(get_content_sha256_cksum(headers, ServiceType::S3))
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// skip_content_sha256_cksum returns true if caller needs to skip
|
||||
/// payload checksum, false if not.
|
||||
fn skip_content_sha256_cksum(headers: &HeaderMap<HeaderValue>) -> bool {
|
||||
let content_sha256 = if is_request_presigned_signature_v4(headers) {
|
||||
// For presigned requests, check query params first, then headers
|
||||
// Note: In a real implementation, you would need to check query parameters
|
||||
// For now, we'll just check headers
|
||||
headers.get(AMZ_CONTENT_SHA256)
|
||||
} else {
|
||||
headers.get(AMZ_CONTENT_SHA256)
|
||||
};
|
||||
|
||||
// Skip if no header was set
|
||||
let Some(header_value) = content_sha256 else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let Ok(value) = header_value.to_str() else {
|
||||
return true;
|
||||
};
|
||||
|
||||
// If x-amz-content-sha256 is set and the value is not
|
||||
// 'UNSIGNED-PAYLOAD' we should validate the content sha256.
|
||||
match value {
|
||||
v if v == UNSIGNED_PAYLOAD || v == UNSIGNED_PAYLOAD_TRAILER => true,
|
||||
v if v == EMPTY_STRING_SHA256_HASH => {
|
||||
// some broken clients set empty-sha256
|
||||
// with > 0 content-length in the body,
|
||||
// we should skip such clients and allow
|
||||
// blindly such insecure clients only if
|
||||
// S3 strict compatibility is disabled.
|
||||
|
||||
// We return true only in situations when
|
||||
// deployment has asked RustFS to allow for
|
||||
// such broken clients and content-length > 0.
|
||||
// For now, we'll assume strict compatibility is disabled
|
||||
// In a real implementation, you would check a global config
|
||||
if let Some(content_length) = headers.get("content-length") {
|
||||
if let Ok(length_str) = content_length.to_str() {
|
||||
if let Ok(length) = length_str.parse::<i64>() {
|
||||
return length > 0; // && !global_server_ctxt.strict_s3_compat
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns SHA256 for calculating canonical-request.
|
||||
fn get_content_sha256_cksum(headers: &HeaderMap<HeaderValue>, service_type: ServiceType) -> String {
|
||||
if service_type == ServiceType::STS {
|
||||
// For STS requests, we would need to read the body and calculate SHA256
|
||||
// This is a simplified implementation - in practice you'd need access to the request body
|
||||
// For now, we'll return a placeholder
|
||||
return "sts-body-sha256-placeholder".to_string();
|
||||
}
|
||||
|
||||
let (default_sha256_cksum, content_sha256) = if is_request_presigned_signature_v4(headers) {
|
||||
// For a presigned request we look at the query param for sha256.
|
||||
// X-Amz-Content-Sha256, if not set in presigned requests, checksum
|
||||
// will default to 'UNSIGNED-PAYLOAD'.
|
||||
(UNSIGNED_PAYLOAD.to_string(), headers.get(AMZ_CONTENT_SHA256))
|
||||
} else {
|
||||
// X-Amz-Content-Sha256, if not set in signed requests, checksum
|
||||
// will default to sha256([]byte("")).
|
||||
(EMPTY_STRING_SHA256_HASH.to_string(), headers.get(AMZ_CONTENT_SHA256))
|
||||
};
|
||||
|
||||
// We found 'X-Amz-Content-Sha256' return the captured value.
|
||||
if let Some(header_value) = content_sha256 {
|
||||
if let Ok(value) = header_value.to_str() {
|
||||
return value.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// We couldn't find 'X-Amz-Content-Sha256'.
|
||||
default_sha256_cksum
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user