From 8588188cac878b991a2793852d162e048b2634f1 Mon Sep 17 00:00:00 2001 From: reatang Date: Thu, 15 Jan 2026 00:09:52 +0800 Subject: [PATCH] Refactor the relationship between SSE and KMS, and decouple them through interfaces. --- rustfs/src/storage/sse.rs | 264 +++++++++++++++++++++++++++++++++----- 1 file changed, 229 insertions(+), 35 deletions(-) diff --git a/rustfs/src/storage/sse.rs b/rustfs/src/storage/sse.rs index d9698ce6..a16dad40 100644 --- a/rustfs/src/storage/sse.rs +++ b/rustfs/src/storage/sse.rs @@ -76,6 +76,7 @@ use aes_gcm::{ Aes256Gcm, Key, Nonce, aead::{Aead, KeyInit}, }; +use async_trait::async_trait; use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD}; use chrono::Utc; use rand::RngCore; @@ -89,6 +90,7 @@ use rustfs_kms::{ use rustfs_rio::{DecryptReader, EncryptReader, HardLimitReader, Reader, WarpReader}; use s3s::dto::ServerSideEncryption; use std::collections::HashMap; +use std::sync::{Arc, OnceLock}; use tokio::io::{AsyncRead, AsyncSeek}; use tracing::{debug, error, warn}; @@ -600,10 +602,6 @@ async fn apply_managed_encryption_material( // During UploadPart, we use the same base nonce with incremented counter // This is handled externally, so here we just generate the base material - let Some(service) = get_global_encryption_service().await else { - return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized"))); - }; - if !is_managed_sse(&algorithm) { return Err(ApiError::from(StorageError::other(format!( "Unsupported server-side encryption algorithm: {}", @@ -623,17 +621,23 @@ async fn apply_managed_encryption_material( context = context.with_size(content_size as u64); } + // Determine KMS key ID to use let mut kms_key_candidate = kms_key_id.clone().map(|s| s.to_string()); if kms_key_candidate.is_none() { - kms_key_candidate = service.get_default_key_id().cloned(); + // Try to get default key from KMS service (if available) + if let Some(service) = get_global_encryption_service().await { + kms_key_candidate = service.get_default_key_id().cloned(); + } } let kms_key_to_use = kms_key_candidate .clone() .ok_or_else(|| ApiError::from(StorageError::other("No KMS key available for managed server-side encryption")))?; - let (data_key, encrypted_data_key) = service - .create_data_key(&kms_key_candidate, &context) + // Use factory pattern to get provider (test or production mode) + let provider = get_sse_dek_provider().await?; + let (data_key, encrypted_data_key) = provider + .generate_sse_dek(bucket, key, &kms_key_to_use) .await .map_err(|e| ApiError::from(StorageError::other(format!("Failed to create data key: {e}"))))?; @@ -649,7 +653,19 @@ async fn apply_managed_encryption_material( encrypted_data_key, }; - let mut metadata = service.metadata_to_headers(&encryption_metadata); + // Build metadata headers + let mut metadata = HashMap::new(); + + // Try to use service for metadata formatting if available, otherwise build manually + if let Some(service) = get_global_encryption_service().await { + metadata = service.metadata_to_headers(&encryption_metadata); + } else { + // Manual metadata building for test mode + metadata.insert("x-rustfs-encryption-key".to_string(), BASE64_STANDARD.encode(&encryption_metadata.encrypted_data_key)); + metadata.insert("x-rustfs-encryption-iv".to_string(), BASE64_STANDARD.encode(&encryption_metadata.iv)); + metadata.insert("x-rustfs-encryption-algorithm".to_string(), encryption_metadata.algorithm.clone()); + } + metadata.insert( "x-rustfs-encryption-original-size".to_string(), encryption_metadata.original_size.to_string(), @@ -688,27 +704,61 @@ async fn apply_managed_decryption_material( return Ok(None); } - let Some(service) = get_global_encryption_service().await else { - return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized"))); + // Parse metadata - try using service if available, otherwise parse manually + let (encrypted_data_key, iv, algorithm) = if let Some(service) = get_global_encryption_service().await { + // Production mode: use service for metadata parsing + let parsed = service + .headers_to_metadata(metadata) + .map_err(|e| ApiError::from(StorageError::other(format!("Failed to parse encryption metadata: {e}"))))?; + + if parsed.iv.len() != 12 { + return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes"))); + } + + (parsed.encrypted_data_key, parsed.iv, parsed.algorithm) + } else { + // Test mode: parse metadata manually + let encrypted_key_b64 = metadata + .get("x-rustfs-encryption-key") + .ok_or_else(|| ApiError::from(StorageError::other("Missing encrypted key in metadata")))?; + let encrypted_data_key = BASE64_STANDARD + .decode(encrypted_key_b64) + .map_err(|e| ApiError::from(StorageError::other(format!("Failed to decode encrypted key: {e}"))))?; + + let iv_b64 = metadata + .get("x-rustfs-encryption-iv") + .ok_or_else(|| ApiError::from(StorageError::other("Missing IV in metadata")))?; + let iv = BASE64_STANDARD + .decode(iv_b64) + .map_err(|e| ApiError::from(StorageError::other(format!("Failed to decode IV: {e}"))))?; + + if iv.len() != 12 { + return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes"))); + } + + let algorithm = metadata + .get("x-rustfs-encryption-algorithm") + .cloned() + .unwrap_or_else(|| "AES256".to_string()); + + (encrypted_data_key, iv, algorithm) }; - let parsed = service - .headers_to_metadata(metadata) - .map_err(|e| ApiError::from(StorageError::other(format!("Failed to parse encryption metadata: {e}"))))?; + // Extract KMS key ID from metadata (optional, used for provider context) + let kms_key_id = metadata + .get("x-amz-server-side-encryption-aws-kms-key-id") + .cloned() + .unwrap_or_else(|| "default".to_string()); - if parsed.iv.len() != 12 { - return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes"))); - } - - let context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string()); - let data_key = service - .decrypt_data_key(&parsed.encrypted_data_key, &context) + // Use factory pattern to get provider (test or production mode) + let provider = get_sse_dek_provider().await?; + let key_bytes = provider + .decrypt_sse_dek(&encrypted_data_key, &kms_key_id) .await .map_err(|e| ApiError::from(StorageError::other(format!("Failed to decrypt data key: {e}"))))?; - let key_bytes = data_key.plaintext_key; let mut base_nonce = [0u8; 12]; - base_nonce.copy_from_slice(&parsed.iv[..12]); + base_nonce.copy_from_slice(&iv[..12]); let nonce = if let Some(part_num) = part_number { derive_part_nonce(base_nonce, part_num) @@ -720,7 +770,7 @@ async fn apply_managed_decryption_material( .get("x-rustfs-encryption-original-size") .and_then(|s| s.parse::().ok()); - let encryption_type = match parsed.algorithm.as_str() { + let encryption_type = match algorithm.as_str() { "AES256" => EncryptionType::SseS3, "aws:kms" => EncryptionType::SseKms, _ => EncryptionType::SseS3, @@ -774,18 +824,88 @@ pub struct SsecParams { } // ============================================================================ -// Test purpose, custom SSE CMK implementation (SSE-S3 / SSE-KMS) Definitions +// SSE DEK Provider Abstraction (Factory Pattern) // ============================================================================ -// Define a trait to get SSE DEK by bucket and key -trait SseDekProvider { - /// Generate an SSE DEK +/// Trait for SSE data encryption key management +/// Abstracts the source of encryption keys (KMS, test provider, etc.) +#[async_trait] +pub trait SseDekProvider: Send + Sync { + /// Generate an SSE data encryption key async fn generate_sse_dek(&self, bucket: &str, key: &str, kms_key_id: &str) -> Result<(DataKey, Vec), ApiError>; - /// Decrypt an SSE DEK (returns only plaintext key, nonce should be read from metadata) + /// Decrypt an SSE data encryption key (returns only plaintext key, nonce should be read from metadata) async fn decrypt_sse_dek(&self, encrypted_dek: &[u8], kms_key_id: &str) -> Result<[u8; 32], ApiError>; } +// ============================================================================ +// Production KMS-backed DEK Provider +// ============================================================================ + +/// Production KMS-backed DEK provider +/// Wraps the global ObjectEncryptionService to provide SSE DEK operations +struct KmsSseDekProvider { + service: Arc, +} + +impl KmsSseDekProvider { + /// Create a new KMS-backed provider + pub async fn new() -> Result { + let service = get_global_encryption_service() + .await + .ok_or_else(|| ApiError::from(StorageError::other( + "KMS encryption service is not initialized" + )))?; + Ok(Self { service }) + } +} + +#[async_trait] +impl SseDekProvider for KmsSseDekProvider { + async fn generate_sse_dek( + &self, + bucket: &str, + key: &str, + kms_key_id: &str, + ) -> Result<(DataKey, Vec), ApiError> { + let context = ObjectEncryptionContext::new( + bucket.to_string(), + key.to_string() + ); + + let kms_key_option = Some(kms_key_id.to_string()); + let (data_key, encrypted_data_key) = self.service + .create_data_key(&kms_key_option, &context) + .await + .map_err(|e| ApiError::from(StorageError::other( + format!("Failed to create data key: {}", e) + )))?; + + Ok((data_key, encrypted_data_key)) + } + + async fn decrypt_sse_dek( + &self, + encrypted_dek: &[u8], + _kms_key_id: &str, + ) -> Result<[u8; 32], ApiError> { + // Create a minimal context for decryption + let context = ObjectEncryptionContext::new("".to_string(), "".to_string()); + let data_key = self.service + .decrypt_data_key(encrypted_dek, &context) + .await + .map_err(|e| ApiError::from(StorageError::other( + format!("Failed to decrypt data key: {}", e) + )))?; + + Ok(data_key.plaintext_key) + } +} + +// ============================================================================ +// Test/Simple DEK Provider +// ============================================================================ + // Implement a simple SSE DEK provider for testing purposes struct SimpleSseDekProvider { cmk_ids: HashMap, @@ -794,6 +914,12 @@ struct SimpleSseDekProvider { // __RUSTFS_SSE_SIMPLE_CMK_ID format: key-id1:base64_key1,key-id2:base64_key2,... impl SimpleSseDekProvider { + /// Create a SimpleSseDekProvider with predefined keys (for testing) + #[cfg(test)] + pub fn new_with_keys(cmk_ids: HashMap) -> Self { + Self { cmk_ids } + } + pub fn new() -> Self { let cmk_id = std::env::var("__RUSTFS_SSE_SIMPLE_CMK_ID").unwrap_or_else(|_| "".to_string()); @@ -888,6 +1014,7 @@ impl SimpleSseDekProvider { } } +#[async_trait] impl SseDekProvider for SimpleSseDekProvider { async fn generate_sse_dek(&self, _bucket: &str, _key: &str, kms_key_id: &str) -> Result<(DataKey, Vec), ApiError> { let cmk_value = self @@ -933,6 +1060,65 @@ impl SseDekProvider for SimpleSseDekProvider { } } +// ============================================================================ +// Factory Function for SSE DEK Provider +// ============================================================================ + +/// Global SSE DEK provider cache +static GLOBAL_SSE_DEK_PROVIDER: OnceLock> = OnceLock::new(); + +/// Get or initialize the global SSE DEK provider +/// +/// Factory function that automatically selects the appropriate provider: +/// - If `__RUSTFS_SSE_SIMPLE_CMK_ID` environment variable exists: use SimpleSseDekProvider (test mode) +/// - Otherwise: use KmsSseDekProvider (production mode with real KMS) +/// +/// # Returns +/// Arc to the global SSE DEK provider instance +/// +/// # Example +/// ```rust,ignore +/// let provider = get_sse_dek_provider().await?; +/// let (data_key, encrypted_dek) = provider +/// .generate_sse_dek("bucket", "key", "kms-key-id") +/// .await?; +/// ``` +pub async fn get_sse_dek_provider() -> Result, ApiError> { + // Check if already initialized + if let Some(provider) = GLOBAL_SSE_DEK_PROVIDER.get() { + return Ok(provider.clone()); + } + + // Determine provider based on environment variable + let provider: Arc = if std::env::var("__RUSTFS_SSE_SIMPLE_CMK_ID").is_ok() { + debug!("Using SimpleSseDekProvider (test mode) based on __RUSTFS_SSE_SIMPLE_CMK_ID"); + Arc::new(SimpleSseDekProvider::new()) + } else { + debug!("Using KmsSseDekProvider (production mode)"); + Arc::new(KmsSseDekProvider::new().await?) + }; + + // Store in global cache + GLOBAL_SSE_DEK_PROVIDER + .set(provider.clone()) + .map_err(|_| ApiError::from(StorageError::other( + "Failed to initialize global SSE DEK provider (already set)" + )))?; + + Ok(provider) +} + +/// Reset the global SSE DEK provider (for testing only) +/// +/// Note: OnceLock doesn't support reset in stable Rust. +/// Tests should set environment variables before first call to `get_sse_dek_provider()`. +#[cfg(test)] +#[allow(dead_code)] +pub fn reset_sse_dek_provider() { + // OnceLock doesn't support reset - this is a documentation placeholder + // Consider using arc_swap::ArcSwap if runtime reset is needed +} + // ============================================================================ // Legacy Functions (SSE-S3 / SSE-KMS) // ============================================================================ @@ -1428,7 +1614,9 @@ mod tests { use tokio::io::AsyncReadExt; // 1. Setup: Create SimpleSseDekProvider with test CMK - let provider = SimpleSseDekProvider::new(); + let mut test_keys = HashMap::new(); + test_keys.insert("test-key".to_string(), [42u8; 32]); + let provider = SimpleSseDekProvider::new_with_keys(test_keys); // 2. Generate a data encryption key let bucket = "test-bucket"; @@ -1494,8 +1682,10 @@ mod tests { use std::io::Cursor; use tokio::io::AsyncReadExt; - // Test with larger data to ensure streaming works correctly - let provider = SimpleSseDekProvider::new(); + // 1. Setup: Create SimpleSseDekProvider with test CMK + let mut test_keys = HashMap::new(); + test_keys.insert("test-key".to_string(), [42u8; 32]); + let provider = SimpleSseDekProvider::new_with_keys(test_keys); let bucket = "test-bucket"; let key = "test-key-large"; @@ -1545,8 +1735,10 @@ mod tests { use std::io::Cursor; use tokio::io::AsyncReadExt; - // Verify that different nonces produce different ciphertext - let provider = SimpleSseDekProvider::new(); + // 1. Setup: Create SimpleSseDekProvider with test CMK + let mut test_keys = HashMap::new(); + test_keys.insert("test-key".to_string(), [42u8; 32]); + let provider = SimpleSseDekProvider::new_with_keys(test_keys); let bucket = "test-bucket"; let key = "test-key"; @@ -1595,8 +1787,10 @@ mod tests { use std::io::Cursor; use tokio::io::AsyncReadExt; - // Test the full cycle: generate -> encrypt DEK -> decrypt DEK -> use for data encryption - let provider = SimpleSseDekProvider::new(); + // 1. Setup: Create SimpleSseDekProvider with test CMK + let mut test_keys = HashMap::new(); + test_keys.insert("test-key".to_string(), [42u8; 32]); + let provider = SimpleSseDekProvider::new_with_keys(test_keys); let bucket = "test-bucket"; let key = "test-key";