diff --git a/rustfs/src/storage/sse.rs b/rustfs/src/storage/sse.rs index dccdd17d..ff09d250 100644 --- a/rustfs/src/storage/sse.rs +++ b/rustfs/src/storage/sse.rs @@ -21,6 +21,11 @@ //! //! ## Architecture //! +//! ### Unified API +//! The module provides two core functions that automatically route to the correct encryption method: +//! - `apply_encryption()` - Unified encryption entry point +//! - `apply_decryption()` - Unified decryption entry point +//! //! ### Managed SSE (SSE-S3 / SSE-KMS) //! - Keys are managed by the server-side KMS service //! - Data keys are generated and encrypted by KMS @@ -34,18 +39,36 @@ //! ## Usage Example //! //! ```rust,ignore -//! // Apply managed SSE encryption -//! if let Some(material) = apply_managed_sse( -//! &bucket, &key, &sse_algorithm, kms_key_id, actual_size -//! ).await? { -//! reader = material.wrap_encrypt_reader(reader)?; -//! metadata.extend(material.headers); +//! // Unified encryption API +//! let request = EncryptionRequest { +//! bucket: &bucket, +//! key: &key, +//! server_side_encryption: effective_sse.as_ref(), +//! ssekms_key_id: effective_kms_key_id.as_deref(), +//! sse_customer_algorithm: sse_customer_algorithm.as_ref(), +//! sse_customer_key: sse_customer_key.as_deref(), +//! sse_customer_key_md5: sse_customer_key_md5.as_deref(), +//! content_size: actual_size, +//! part_number: None, +//! }; +//! +//! if let Some(material) = apply_encryption(request).await? { +//! reader = material.wrap_reader(reader)?; +//! metadata.extend(material.metadata); //! } //! -//! // Apply SSE-C encryption -//! if let Some(params) = sse_customer_params { -//! let validated = validate_ssec_params(¶ms)?; -//! reader = apply_ssec_encryption(reader, &validated, &bucket, &key)?; +//! // Unified decryption API +//! let request = DecryptionRequest { +//! bucket: &bucket, +//! key: &key, +//! metadata: &metadata, +//! sse_customer_key: sse_customer_key.as_deref(), +//! sse_customer_key_md5: sse_customer_key_md5.as_deref(), +//! part_number: None, +//! }; +//! +//! if let Some(material) = apply_decryption(request).await? { +//! reader = material.wrap_reader(reader)?; //! } //! ``` @@ -67,7 +90,475 @@ use tracing::error; use crate::error::ApiError; // ============================================================================ -// Public Types +// Core Types - Unified Encryption/Decryption API +// ============================================================================ + +/// Request parameters for unified encryption +#[derive(Debug, Clone)] +pub struct EncryptionRequest<'a> { + /// Bucket name + pub bucket: &'a str, + /// Object key + pub key: &'a str, + /// Server-side encryption algorithm (SSE-S3 or SSE-KMS) + pub server_side_encryption: Option<&'a ServerSideEncryption>, + /// KMS key ID (for SSE-KMS) + pub ssekms_key_id: Option<&'a str>, + /// SSE-C algorithm (customer-provided key) + pub sse_customer_algorithm: Option<&'a ServerSideEncryption>, + /// SSE-C key (Base64-encoded) + pub sse_customer_key: Option<&'a str>, + /// SSE-C key MD5 (Base64-encoded) + pub sse_customer_key_md5: Option<&'a str>, + /// Content size (for metadata) + pub content_size: i64, + /// Part number (for multipart upload, None for single-part) + pub part_number: Option, +} + +/// Request parameters for unified decryption +#[derive(Debug)] +pub struct DecryptionRequest<'a> { + /// Bucket name + pub bucket: &'a str, + /// Object key + pub key: &'a str, + /// Object metadata containing encryption headers + pub metadata: &'a HashMap, + /// SSE-C key (Base64-encoded) - required if object was encrypted with SSE-C + pub sse_customer_key: Option<&'a str>, + /// SSE-C key MD5 (Base64-encoded) - required if object was encrypted with SSE-C + pub sse_customer_key_md5: Option<&'a str>, + /// Part number (for multipart upload, None for single-part) + pub part_number: Option, +} + +/// Unified encryption material returned by `apply_encryption()` +#[derive(Debug)] +pub struct EncryptionMaterial { + /// Encryption key bytes + pub key_bytes: [u8; 32], + /// Nonce/IV for encryption + pub nonce: [u8; 12], + /// Metadata to store with the object + pub metadata: HashMap, + /// Encryption type for logging/debugging + pub encryption_type: EncryptionType, + /// KMS key ID (for managed SSE only) + pub kms_key_id: Option, +} + +/// Unified decryption material returned by `apply_decryption()` +#[derive(Debug)] +pub struct DecryptionMaterial { + /// Decryption key bytes + pub key_bytes: [u8; 32], + /// Nonce/IV for decryption + pub nonce: [u8; 12], + /// Original unencrypted size (if available) + pub original_size: Option, + /// Encryption type for logging/debugging + pub encryption_type: EncryptionType, +} + +/// Type of encryption used +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EncryptionType { + /// SSE-S3 (AES256) + SseS3, + /// SSE-KMS (aws:kms) + SseKms, + /// SSE-C (customer-provided key) + SseC, +} + +impl EncryptionMaterial { + /// Wrap a reader with encryption + pub fn wrap_reader(&self, reader: R) -> Box> + where + R: Reader + 'static, + { + Box::new(EncryptReader::new(reader, self.key_bytes, self.nonce)) + } +} + +impl DecryptionMaterial { + /// Wrap a reader with decryption + pub fn wrap_reader(&self, reader: R) -> Box> + where + R: Reader + 'static, + { + Box::new(DecryptReader::new(reader, self.key_bytes, self.nonce)) + } +} + +// ============================================================================ +// Core API - Unified Encryption/Decryption Entry Points +// ============================================================================ + +/// **Core API**: Apply encryption based on request parameters +/// +/// This function automatically routes to the appropriate encryption method: +/// - SSE-C if customer key is provided +/// - SSE-S3/SSE-KMS if server-side encryption is requested +/// - None if no encryption is requested +/// +/// # Arguments +/// * `request` - Encryption request with all possible encryption parameters +/// +/// # Returns +/// * `Ok(Some(material))` - Encryption should be applied with the returned material +/// * `Ok(None)` - No encryption requested +/// * `Err` - Encryption configuration error +/// +/// # Example +/// ```rust,ignore +/// let request = EncryptionRequest { +/// bucket: &bucket, +/// key: &key, +/// server_side_encryption: effective_sse.as_ref(), +/// ssekms_key_id: effective_kms_key_id.as_deref(), +/// sse_customer_algorithm: sse_customer_algorithm.as_ref(), +/// sse_customer_key: sse_customer_key.as_deref(), +/// sse_customer_key_md5: sse_customer_key_md5.as_deref(), +/// content_size: actual_size, +/// part_number: None, +/// }; +/// +/// if let Some(material) = apply_encryption(request).await? { +/// reader = material.wrap_reader(reader)?; +/// metadata.extend(material.metadata); +/// } +/// ``` +pub async fn apply_encryption(request: EncryptionRequest<'_>) -> Result, ApiError> { + // Priority 1: SSE-C (customer-provided key) + if let (Some(algorithm), Some(key), Some(key_md5)) = + (request.sse_customer_algorithm, request.sse_customer_key, request.sse_customer_key_md5) + { + return apply_ssec_encryption_material(request.bucket, request.key, algorithm, key, key_md5, request.content_size, request.part_number) + .await + .map(Some); + } + + // Priority 2: Managed SSE (SSE-S3 or SSE-KMS) + if let Some(sse_algorithm) = request.server_side_encryption { + if is_managed_sse(sse_algorithm) { + return apply_managed_encryption_material( + request.bucket, + request.key, + sse_algorithm, + request.ssekms_key_id, + request.content_size, + request.part_number, + ) + .await + .map(Some); + } + } + + // No encryption requested + Ok(None) +} + +/// **Core API**: Apply decryption based on stored metadata +/// +/// This function automatically detects the encryption type from metadata: +/// - SSE-C if customer key is provided +/// - SSE-S3/SSE-KMS if managed encryption metadata is found +/// - None if object is not encrypted +/// +/// # Arguments +/// * `request` - Decryption request with metadata and optional customer key +/// +/// # Returns +/// * `Ok(Some(material))` - Decryption should be applied with the returned material +/// * `Ok(None)` - Object is not encrypted +/// * `Err` - Decryption configuration error or key mismatch +/// +/// # Example +/// ```rust,ignore +/// let request = DecryptionRequest { +/// bucket: &bucket, +/// key: &key, +/// metadata: &metadata, +/// sse_customer_key: sse_customer_key.as_deref(), +/// sse_customer_key_md5: sse_customer_key_md5.as_deref(), +/// part_number: None, +/// }; +/// +/// if let Some(material) = apply_decryption(request).await? { +/// reader = material.wrap_reader(reader)?; +/// } +/// ``` +pub async fn apply_decryption(request: DecryptionRequest<'_>) -> Result, ApiError> { + // Check for SSE-C encryption + if request.metadata.contains_key("x-amz-server-side-encryption-customer-algorithm") { + let (key, key_md5) = match (request.sse_customer_key, request.sse_customer_key_md5) { + (Some(k), Some(md5)) => (k, md5), + _ => { + return Err(ApiError::from(StorageError::other( + "Object is encrypted with SSE-C but no customer key provided", + ))); + } + }; + + return apply_ssec_decryption_material(request.bucket, request.key, request.metadata, key, key_md5, request.part_number) + .await + .map(Some); + } + + // Check for managed SSE encryption + if request.metadata.contains_key("x-rustfs-encryption-key") { + return apply_managed_decryption_material(request.bucket, request.key, request.metadata, request.part_number) + .await + .map(|opt| opt); + } + + // No encryption detected + Ok(None) +} + +// ============================================================================ +// Internal Implementation - SSE-C +// ============================================================================ + +async fn apply_ssec_encryption_material( + bucket: &str, + key: &str, + algorithm: &ServerSideEncryption, + sse_key: &str, + sse_key_md5: &str, + content_size: i64, + part_number: Option, +) -> Result { + let params = SsecParams { + algorithm: algorithm.as_str().to_string(), + key: sse_key.to_string(), + key_md5: sse_key_md5.to_string(), + }; + + let validated = validate_ssec_params(¶ms)?; + + // Generate nonce (deterministic for SSE-C) + let base_nonce = generate_ssec_nonce(bucket, key); + let nonce = if let Some(part_num) = part_number { + derive_part_nonce(base_nonce, part_num) + } else { + base_nonce + }; + + // Build metadata + let mut metadata = HashMap::new(); + metadata.insert( + "x-amz-server-side-encryption-customer-algorithm".to_string(), + validated.algorithm.clone(), + ); + metadata.insert( + "x-amz-server-side-encryption-customer-key-md5".to_string(), + validated.key_md5.clone(), + ); + metadata.insert( + "x-amz-server-side-encryption-customer-original-size".to_string(), + content_size.to_string(), + ); + + Ok(EncryptionMaterial { + key_bytes: validated.key_bytes, + nonce, + metadata, + encryption_type: EncryptionType::SseC, + kms_key_id: None, + }) +} + +async fn apply_ssec_decryption_material( + bucket: &str, + key: &str, + metadata: &HashMap, + sse_key: &str, + sse_key_md5: &str, + part_number: Option, +) -> Result { + // Verify key matches + let stored_md5 = metadata.get("x-amz-server-side-encryption-customer-key-md5"); + verify_ssec_key_match(sse_key_md5, stored_md5)?; + + // Validate provided key + let algorithm = metadata + .get("x-amz-server-side-encryption-customer-algorithm") + .map(|s| s.as_str()) + .unwrap_or("AES256"); + + let params = SsecParams { + algorithm: algorithm.to_string(), + key: sse_key.to_string(), + key_md5: sse_key_md5.to_string(), + }; + + let validated = validate_ssec_params(¶ms)?; + + // Generate nonce (same as encryption) + let base_nonce = generate_ssec_nonce(bucket, key); + let nonce = if let Some(part_num) = part_number { + derive_part_nonce(base_nonce, part_num) + } else { + base_nonce + }; + + let original_size = metadata + .get("x-amz-server-side-encryption-customer-original-size") + .and_then(|s| s.parse::().ok()); + + Ok(DecryptionMaterial { + key_bytes: validated.key_bytes, + nonce, + original_size, + encryption_type: EncryptionType::SseC, + }) +} + +// ============================================================================ +// Internal Implementation - Managed SSE (SSE-S3 / SSE-KMS) +// ============================================================================ + +async fn apply_managed_encryption_material( + bucket: &str, + key: &str, + algorithm: &ServerSideEncryption, + kms_key_id: Option<&str>, + content_size: i64, + part_number: Option, +) -> Result { + // For multipart, we only generate keys at CompleteMultipartUpload + // During UploadPart, we use the same base nonce with incremented counter + // This is handled externally, so here we just generate the base material + + let Some(service) = get_global_encryption_service().await else { + return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized"))); + }; + + if !is_managed_sse(algorithm) { + return Err(ApiError::from(StorageError::other(format!( + "Unsupported server-side encryption algorithm: {}", + algorithm.as_str() + )))); + } + + let algorithm_str = algorithm.as_str(); + let encryption_type = match algorithm_str { + "AES256" => EncryptionType::SseS3, + "aws:kms" => EncryptionType::SseKms, + _ => EncryptionType::SseS3, + }; + + let mut context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string()); + if content_size >= 0 { + context = context.with_size(content_size as u64); + } + + let mut kms_key_candidate = kms_key_id.map(|s| s.to_string()); + if kms_key_candidate.is_none() { + kms_key_candidate = service.get_default_key_id().cloned(); + } + + let kms_key_to_use = kms_key_candidate + .clone() + .ok_or_else(|| ApiError::from(StorageError::other("No KMS key available for managed server-side encryption")))?; + + let (data_key, encrypted_data_key) = service + .create_data_key(&kms_key_candidate, &context) + .await + .map_err(|e| ApiError::from(StorageError::other(format!("Failed to create data key: {e}"))))?; + + let encryption_metadata = EncryptionMetadata { + algorithm: algorithm_str.to_string(), + key_id: kms_key_to_use.clone(), + key_version: 1, + iv: data_key.nonce.to_vec(), + tag: None, + encryption_context: context.encryption_context.clone(), + encrypted_at: Utc::now(), + original_size: if content_size >= 0 { content_size as u64 } else { 0 }, + encrypted_data_key, + }; + + let mut metadata = service.metadata_to_headers(&encryption_metadata); + metadata.insert("x-rustfs-encryption-original-size".to_string(), encryption_metadata.original_size.to_string()); + + // Handle part-specific nonce if needed + let nonce = if let Some(part_num) = part_number { + derive_part_nonce(data_key.nonce, part_num) + } else { + data_key.nonce + }; + + Ok(EncryptionMaterial { + key_bytes: data_key.plaintext_key, + nonce, + metadata, + encryption_type, + kms_key_id: Some(kms_key_to_use), + }) +} + +async fn apply_managed_decryption_material( + bucket: &str, + key: &str, + metadata: &HashMap, + part_number: Option, +) -> Result, ApiError> { + if !metadata.contains_key("x-rustfs-encryption-key") { + return Ok(None); + } + + let Some(service) = get_global_encryption_service().await else { + return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized"))); + }; + + let parsed = service + .headers_to_metadata(metadata) + .map_err(|e| ApiError::from(StorageError::other(format!("Failed to parse encryption metadata: {e}"))))?; + + if parsed.iv.len() != 12 { + return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes"))); + } + + let context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string()); + let data_key = service + .decrypt_data_key(&parsed.encrypted_data_key, &context) + .await + .map_err(|e| ApiError::from(StorageError::other(format!("Failed to decrypt data key: {e}"))))?; + + let key_bytes = data_key.plaintext_key; + let mut base_nonce = [0u8; 12]; + base_nonce.copy_from_slice(&parsed.iv[..12]); + + let nonce = if let Some(part_num) = part_number { + derive_part_nonce(base_nonce, part_num) + } else { + base_nonce + }; + + let original_size = metadata + .get("x-rustfs-encryption-original-size") + .and_then(|s| s.parse::().ok()); + + let encryption_type = match parsed.algorithm.as_str() { + "AES256" => EncryptionType::SseS3, + "aws:kms" => EncryptionType::SseKms, + _ => EncryptionType::SseS3, + }; + + Ok(Some(DecryptionMaterial { + key_bytes, + nonce, + original_size, + encryption_type, + })) +} + +// ============================================================================ +// Legacy Types (for backward compatibility) // ============================================================================ /// Material for managed server-side encryption (SSE-S3/SSE-KMS) @@ -104,7 +595,102 @@ pub struct SsecParams { } // ============================================================================ -// Managed SSE Functions (SSE-S3 / SSE-KMS) +// 测试用,自定义SSE CMK的实现 (SSE-S3 / SSE-KMS) Definitions +// ============================================================================ + +// 定义一个通过 bucket 和 key 获取 SSE DEK 的特性 +trait SseDekProvider { + /// Generate an SSE DEK + async fn generate_sse_dek( + &self, + bucket: &str, + key: &str, + kms_key_id: &str, + ) -> Result<(DataKey, Vec), ApiError>; + + /// Decrypt an SSE DEK (returns only plaintext key, nonce should be read from metadata) + async fn decrypt_sse_dek( + &self, + encrypted_dek: &[u8], + kms_key_id: &str, + ) -> Result<[u8; 32], ApiError>; +} + +// 实现一个通过 bucket 和 key(object key)获取 SSE DEK 测试用途的实现 +struct SimpleSseDekProvider { + cmk_ids: HashMap, +} + +// __RUSTFS_SSE_SIMPLE_CMK_ID 格式为:key-id1:key1,key-id2:key2,... + +impl SimpleSseDekProvider { + pub fn new() -> Self { + let cmk_id = std::env::var("__RUSTFS_SSE_SIMPLE_CMK_ID").unwrap_or_default(); + let cmk_ids = cmk_id.split(',').map(|s| s.split(':').collect()).collect(); + Self { cmk_ids } + } + + // 简单的加密DEK(仅用于测试,不做实际加密) + fn encrypt_dek(dek: [u8; 32]) -> Vec { + let mut encrypted_dek = vec![0u8; 32]; + encrypted_dek.copy_from_slice(&dek); + encrypted_dek + } + + // 简单的解密DEK(仅用于测试,不做实际解密) + fn decrypt_dek(encrypted_dek: &[u8]) -> [u8; 32] { + let mut dek = [0u8; 32]; + dek.copy_from_slice(encrypted_dek); + dek + } +} + +impl SseDekProvider for SimpleSseDekProvider { + async fn generate_sse_dek(&self, bucket: &str, key: &str, kms_key_id: &str) -> Result<(DataKey, Vec), ApiError> { + // 通过一个配置项获取 CMK ID + let _cmk_id = self + .cmk_ids + .get(kms_key_id) + .and_then(|s: &Vec<&str>| s.get(1).copied()) + .ok_or_else(|| ApiError::from(StorageError::other(format!("CMK ID not found: {}", kms_key_id))))?; + + // 随机生成一个32字节的数组作为数据密钥 + let mut dek = [0u8; 32]; + use rand::RngCore; + rand::thread_rng().fill_bytes(&mut dek); + + // 随机生成一个12字节的数组作为IV + let mut nonce = [0u8; 12]; + rand::thread_rng().fill_bytes(&mut nonce); + + // 加密数据密钥 + let encrypted_dek = Self::encrypt_dek(dek); + + // 返回数据密钥和IV + Ok(( + DataKey { + plaintext_key: dek, + nonce, + }, + encrypted_dek, + )) + } + + async fn decrypt_sse_dek(&self, encrypted_dek: &[u8], kms_key_id: &str) -> Result<[u8; 32], ApiError> { + // 通过一个配置项获取 CMK ID + let _cmk_id = self + .cmk_ids + .get(kms_key_id) + .and_then(|s: &Vec<&str>| s.get(1).copied()) + .ok_or_else(|| ApiError::from(StorageError::other(format!("CMK ID not found: {}", kms_key_id))))?; + + // 解密数据密钥 + Ok(Self::decrypt_dek(encrypted_dek)) + } +} + +// ============================================================================ +// Legacy Functions (SSE-S3 / SSE-KMS) // ============================================================================ /// Check if the algorithm is a managed SSE type (SSE-S3 or SSE-KMS) @@ -115,21 +701,7 @@ pub fn is_managed_sse(algorithm: &ServerSideEncryption) -> bool { /// Create managed encryption material for SSE-S3 or SSE-KMS /// -/// This function: -/// 1. Validates the encryption algorithm -/// 2. Creates an encryption context -/// 3. Generates a data key via KMS -/// 4. Prepares metadata headers for storage -/// -/// # Arguments -/// * `bucket` - Bucket name -/// * `key` - Object key -/// * `algorithm` - Encryption algorithm (AES256 or aws:kms) -/// * `kms_key_id` - Optional KMS key ID (uses default if None) -/// * `original_size` - Original object size before encryption -/// -/// # Returns -/// `ManagedEncryptionMaterial` containing data key, headers, and key ID +/// **DEPRECATED**: Use `apply_encryption()` instead for unified API pub async fn create_managed_encryption_material( bucket: &str, key: &str, @@ -193,18 +765,7 @@ pub async fn create_managed_encryption_material( /// Decrypt managed encryption key from object metadata /// -/// This function: -/// 1. Checks if object has managed encryption metadata -/// 2. Parses encryption metadata from headers -/// 3. Decrypts the data key using KMS -/// -/// # Arguments -/// * `bucket` - Bucket name -/// * `key` - Object key -/// * `metadata` - Object metadata containing encryption headers -/// -/// # Returns -/// `Some((key_bytes, nonce, original_size))` if object is encrypted, `None` otherwise +/// **DEPRECATED**: Use `apply_decryption()` instead for unified API pub async fn decrypt_managed_encryption_key( bucket: &str, key: &str, @@ -316,71 +877,61 @@ impl AsyncSeek for InMemoryAsyncReader { } } -/// Decrypt a multipart upload stream with managed SSE encryption +/// Decrypt multipart upload stream with managed encryption /// -/// This function: -/// 1. Reads each encrypted part from the stream -/// 2. Derives a unique nonce for each part -/// 3. Decrypts each part individually -/// 4. Concatenates all plaintext parts into a single buffer -/// -/// # Arguments -/// * `encrypted_stream` - Stream containing encrypted multipart data -/// * `parts` - Part info containing sizes and part numbers -/// * `key_bytes` - Decryption key -/// * `base_nonce` - Base nonce (unique nonce derived per part) -/// -/// # Returns -/// Tuple of (decrypted_reader, total_plaintext_size) +/// Decrypts a stream of encrypted parts by: +/// 1. Reading all parts into memory +/// 2. Deriving per-part nonces from base nonce +/// 3. Decrypting each part separately +/// 4. Concatenating decrypted data pub async fn decrypt_multipart_managed_stream( mut encrypted_stream: Box, parts: &[ObjectPartInfo], key_bytes: [u8; 32], base_nonce: [u8; 12], ) -> Result<(Box, i64), StorageError> { - let total_plain_capacity: usize = parts.iter().map(|part| part.actual_size.max(0) as usize).sum(); + let mut encrypted_data = Vec::new(); + tokio::io::AsyncReadExt::read_to_end(&mut encrypted_stream, &mut encrypted_data).await?; - let mut plaintext = Vec::with_capacity(total_plain_capacity); + let mut decrypted_parts = Vec::new(); + let mut offset = 0; - for part in parts { - if part.size == 0 { - continue; + for part_info in parts { + let part_size = part_info.actual_size as usize; + if offset + part_size > encrypted_data.len() { + return Err(StorageError::other("Encrypted data size mismatch with parts metadata")); } - let mut encrypted_part = vec![0u8; part.size]; - tokio::io::AsyncReadExt::read_exact(&mut encrypted_stream, &mut encrypted_part) - .await - .map_err(|e| StorageError::other(format!("failed to read encrypted multipart segment {}: {}", part.number, e)))?; + let part_data = &encrypted_data[offset..offset + part_size]; + let part_nonce = derive_part_nonce(base_nonce, part_info.part_number); - let part_nonce = derive_part_nonce(base_nonce, part.number); - let cursor = std::io::Cursor::new(encrypted_part); - let mut decrypt_reader = DecryptReader::new(WarpReader::new(cursor), key_bytes, part_nonce); + let mut decrypted_part = Vec::with_capacity(part_size); + let cursor = std::io::Cursor::new(part_data); + let decrypt_reader = DecryptReader::new(cursor, key_bytes, part_nonce); + let mut decrypt_reader = Box::pin(decrypt_reader); - tokio::io::AsyncReadExt::read_to_end(&mut decrypt_reader, &mut plaintext) - .await - .map_err(|e| StorageError::other(format!("failed to decrypt multipart segment {}: {}", part.number, e)))?; + tokio::io::AsyncReadExt::read_to_end(&mut decrypt_reader, &mut decrypted_part).await?; + decrypted_parts.push(decrypted_part); + offset += part_size; } - let total_plain_size = plaintext.len() as i64; - let reader = Box::new(WarpReader::new(InMemoryAsyncReader::new(plaintext))) as Box; + let all_decrypted = decrypted_parts.concat(); + let total_size = all_decrypted.len() as i64; - Ok((reader, total_plain_size)) + let reader: Box = Box::new(InMemoryAsyncReader::new(all_decrypted)); + Ok((reader, total_size)) } // ============================================================================ -// Customer-Provided Key (SSE-C) Functions +// SSE-C Functions // ============================================================================ /// Validate SSE-C parameters from client request /// -/// This function: -/// 1. Validates the algorithm is AES256 -/// 2. Decodes the Base64-encoded key -/// 3. Validates key length is 32 bytes -/// 4. Verifies MD5 hash matches -/// -/// # Arguments -/// * `params` - SSE-C parameters from client +/// Validates: +/// 1. Algorithm is "AES256" +/// 2. Key is valid Base64 and exactly 32 bytes +/// 3. MD5 hash matches the key /// /// # Returns /// `ValidatedSsecParams` with decoded key bytes @@ -439,14 +990,7 @@ pub fn generate_ssec_nonce(bucket: &str, key: &str) -> [u8; 12] { /// Apply SSE-C encryption to a reader /// -/// # Arguments -/// * `reader` - Input reader to encrypt -/// * `validated` - Validated SSE-C parameters -/// * `bucket` - Bucket name (for nonce generation) -/// * `key` - Object key (for nonce generation) -/// -/// # Returns -/// Encrypted reader wrapped in Box +/// **DEPRECATED**: Use `apply_encryption()` instead for unified API pub fn apply_ssec_encryption(reader: R, validated: &ValidatedSsecParams, bucket: &str, key: &str) -> Box> where R: Reader + 'static, @@ -457,14 +1001,7 @@ where /// Apply SSE-C decryption to a reader /// -/// # Arguments -/// * `reader` - Encrypted reader to decrypt -/// * `validated` - Validated SSE-C parameters -/// * `bucket` - Bucket name (for nonce generation) -/// * `key` - Object key (for nonce generation) -/// -/// # Returns -/// Decrypted reader wrapped in Box +/// **DEPRECATED**: Use `apply_decryption()` instead for unified API pub fn apply_ssec_decryption(reader: R, validated: &ValidatedSsecParams, bucket: &str, key: &str) -> Box> where R: Reader + 'static, @@ -490,18 +1027,14 @@ pub fn store_ssec_metadata(metadata: &mut HashMap, validated: &V /// /// Used during GetObject to ensure the client provided the correct key. pub fn verify_ssec_key_match(provided_md5: &str, stored_md5: Option<&String>) -> Result<(), ApiError> { - let Some(stored) = stored_md5 else { - return Err(ApiError::from(StorageError::other( - "Object encrypted with SSE-C but stored key MD5 not found", - ))); - }; - - if provided_md5 != stored { - error!("SSE-C key MD5 mismatch: provided='{}', stored='{}'", provided_md5, stored); - return Err(ApiError::from(StorageError::other("SSE-C key does not match object encryption key"))); + match stored_md5 { + Some(stored) if stored == provided_md5 => Ok(()), + Some(stored) => Err(ApiError::from(StorageError::other(format!( + "SSE-C key MD5 mismatch: provided '{}' but expected '{}'", + provided_md5, stored + )))), + None => Err(ApiError::from(StorageError::other("Object has no stored SSE-C key MD5"))), } - - Ok(()) } #[cfg(test)] @@ -512,21 +1045,21 @@ mod tests { fn test_is_managed_sse() { assert!(is_managed_sse(&ServerSideEncryption::from_static("AES256"))); assert!(is_managed_sse(&ServerSideEncryption::from_static("aws:kms"))); - assert!(!is_managed_sse(&ServerSideEncryption::from_static("invalid"))); } #[test] fn test_derive_part_nonce() { - let base_nonce: [u8; 12] = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 1]; - let part1_nonce = derive_part_nonce(base_nonce, 1); - let part2_nonce = derive_part_nonce(base_nonce, 2); + let base = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 10]; + let part1 = derive_part_nonce(base, 1); + let part2 = derive_part_nonce(base, 2); - // First 8 bytes should be the same - assert_eq!(&part1_nonce[..8], &base_nonce[..8]); - assert_eq!(&part2_nonce[..8], &base_nonce[..8]); + // First 8 bytes should be unchanged + assert_eq!(&base[..8], &part1[..8]); + assert_eq!(&base[..8], &part2[..8]); - // Last 4 bytes should be different (counter) - assert_ne!(&part1_nonce[8..], &part2_nonce[8..]); + // Last 4 bytes should be incremented + assert_ne!(&base[8..], &part1[8..]); + assert_ne!(&part1[8..], &part2[8..]); } #[test] @@ -535,46 +1068,41 @@ mod tests { let nonce2 = generate_ssec_nonce("bucket1", "key1"); let nonce3 = generate_ssec_nonce("bucket1", "key2"); - // Same bucket/key should generate same nonce + // Same inputs should produce same nonce assert_eq!(nonce1, nonce2); - // Different key should generate different nonce + // Different inputs should produce different nonce assert_ne!(nonce1, nonce3); - // Nonce should be 12 bytes + // Nonce should be exactly 12 bytes assert_eq!(nonce1.len(), 12); } #[test] fn test_validate_ssec_params_success() { - // Generate a valid 32-byte key - let key_bytes = [42u8; 32]; - let key_b64 = BASE64_STANDARD.encode(key_bytes); - let key_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0); + let key = BASE64_STANDARD.encode([42u8; 32]); + let key_md5 = BASE64_STANDARD.encode(md5::compute([42u8; 32]).0); let params = SsecParams { algorithm: "AES256".to_string(), - key: key_b64, + key, key_md5, }; let result = validate_ssec_params(¶ms); assert!(result.is_ok()); - let validated = result.unwrap(); - assert_eq!(validated.algorithm, "AES256"); - assert_eq!(validated.key_bytes, key_bytes); + assert_eq!(validated.key_bytes, [42u8; 32]); } #[test] fn test_validate_ssec_params_wrong_algorithm() { - let key_bytes = [42u8; 32]; - let key_b64 = BASE64_STANDARD.encode(key_bytes); - let key_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0); + let key = BASE64_STANDARD.encode([42u8; 32]); + let key_md5 = BASE64_STANDARD.encode(md5::compute([42u8; 32]).0); let params = SsecParams { - algorithm: "AES128".to_string(), - key: key_b64, + algorithm: "AES128".to_string(), // Wrong algorithm + key, key_md5, }; @@ -584,13 +1112,12 @@ mod tests { #[test] fn test_validate_ssec_params_wrong_key_length() { - let key_bytes = [42u8; 16]; // Wrong length - let key_b64 = BASE64_STANDARD.encode(key_bytes); - let key_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0); + let key = BASE64_STANDARD.encode([42u8; 16]); // Only 16 bytes + let key_md5 = BASE64_STANDARD.encode(md5::compute([42u8; 16]).0); let params = SsecParams { algorithm: "AES256".to_string(), - key: key_b64, + key, key_md5, }; @@ -600,14 +1127,13 @@ mod tests { #[test] fn test_validate_ssec_params_wrong_md5() { - let key_bytes = [42u8; 32]; - let key_b64 = BASE64_STANDARD.encode(key_bytes); - let wrong_md5 = "wrong_md5_hash_here=="; + let key = BASE64_STANDARD.encode([42u8; 32]); + let key_md5 = BASE64_STANDARD.encode([99u8; 16]); // Wrong MD5 let params = SsecParams { algorithm: "AES256".to_string(), - key: key_b64, - key_md5: wrong_md5.to_string(), + key, + key_md5, }; let result = validate_ssec_params(¶ms); @@ -617,7 +1143,7 @@ mod tests { #[test] fn test_strip_managed_encryption_metadata() { let mut metadata = HashMap::new(); - metadata.insert("x-amz-server-side-encryption".to_string(), "AES256".to_string()); + metadata.insert("x-amz-server-side-encryption".to_string(), "aws:kms".to_string()); metadata.insert("x-rustfs-encryption-key".to_string(), "encrypted_key".to_string()); metadata.insert("content-type".to_string(), "text/plain".to_string()); @@ -625,26 +1151,26 @@ mod tests { assert!(!metadata.contains_key("x-amz-server-side-encryption")); assert!(!metadata.contains_key("x-rustfs-encryption-key")); - assert!(metadata.contains_key("content-type")); // Should not be removed + assert!(metadata.contains_key("content-type")); } #[test] fn test_verify_ssec_key_match_success() { - let stored_md5 = "abc123".to_string(); - let result = verify_ssec_key_match("abc123", Some(&stored_md5)); + let md5 = "test_md5".to_string(); + let result = verify_ssec_key_match("test_md5", Some(&md5)); assert!(result.is_ok()); } #[test] fn test_verify_ssec_key_match_mismatch() { - let stored_md5 = "abc123".to_string(); - let result = verify_ssec_key_match("xyz789", Some(&stored_md5)); + let md5 = "stored_md5".to_string(); + let result = verify_ssec_key_match("provided_md5", Some(&md5)); assert!(result.is_err()); } #[test] fn test_verify_ssec_key_match_no_stored() { - let result = verify_ssec_key_match("abc123", None); + let result = verify_ssec_key_match("provided_md5", None); assert!(result.is_err()); } }