Refactor the relationship between SSE and KMS, and decouple them through interfaces.

This commit is contained in:
reatang
2026-01-15 00:09:52 +08:00
parent d00ce55047
commit 8588188cac

View File

@@ -76,6 +76,7 @@ use aes_gcm::{
Aes256Gcm, Key, Nonce,
aead::{Aead, KeyInit},
};
use async_trait::async_trait;
use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
use chrono::Utc;
use rand::RngCore;
@@ -89,6 +90,7 @@ use rustfs_kms::{
use rustfs_rio::{DecryptReader, EncryptReader, HardLimitReader, Reader, WarpReader};
use s3s::dto::ServerSideEncryption;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use tokio::io::{AsyncRead, AsyncSeek};
use tracing::{debug, error, warn};
@@ -600,10 +602,6 @@ async fn apply_managed_encryption_material(
// During UploadPart, we use the same base nonce with incremented counter
// This is handled externally, so here we just generate the base material
let Some(service) = get_global_encryption_service().await else {
return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized")));
};
if !is_managed_sse(&algorithm) {
return Err(ApiError::from(StorageError::other(format!(
"Unsupported server-side encryption algorithm: {}",
@@ -623,17 +621,23 @@ async fn apply_managed_encryption_material(
context = context.with_size(content_size as u64);
}
// Determine KMS key ID to use
let mut kms_key_candidate = kms_key_id.clone().map(|s| s.to_string());
if kms_key_candidate.is_none() {
kms_key_candidate = service.get_default_key_id().cloned();
// Try to get default key from KMS service (if available)
if let Some(service) = get_global_encryption_service().await {
kms_key_candidate = service.get_default_key_id().cloned();
}
}
let kms_key_to_use = kms_key_candidate
.clone()
.ok_or_else(|| ApiError::from(StorageError::other("No KMS key available for managed server-side encryption")))?;
let (data_key, encrypted_data_key) = service
.create_data_key(&kms_key_candidate, &context)
// Use factory pattern to get provider (test or production mode)
let provider = get_sse_dek_provider().await?;
let (data_key, encrypted_data_key) = provider
.generate_sse_dek(bucket, key, &kms_key_to_use)
.await
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to create data key: {e}"))))?;
@@ -649,7 +653,19 @@ async fn apply_managed_encryption_material(
encrypted_data_key,
};
let mut metadata = service.metadata_to_headers(&encryption_metadata);
// Build metadata headers
let mut metadata = HashMap::new();
// Try to use service for metadata formatting if available, otherwise build manually
if let Some(service) = get_global_encryption_service().await {
metadata = service.metadata_to_headers(&encryption_metadata);
} else {
// Manual metadata building for test mode
metadata.insert("x-rustfs-encryption-key".to_string(), BASE64_STANDARD.encode(&encryption_metadata.encrypted_data_key));
metadata.insert("x-rustfs-encryption-iv".to_string(), BASE64_STANDARD.encode(&encryption_metadata.iv));
metadata.insert("x-rustfs-encryption-algorithm".to_string(), encryption_metadata.algorithm.clone());
}
metadata.insert(
"x-rustfs-encryption-original-size".to_string(),
encryption_metadata.original_size.to_string(),
@@ -688,27 +704,61 @@ async fn apply_managed_decryption_material(
return Ok(None);
}
let Some(service) = get_global_encryption_service().await else {
return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized")));
// Parse metadata - try using service if available, otherwise parse manually
let (encrypted_data_key, iv, algorithm) = if let Some(service) = get_global_encryption_service().await {
// Production mode: use service for metadata parsing
let parsed = service
.headers_to_metadata(metadata)
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to parse encryption metadata: {e}"))))?;
if parsed.iv.len() != 12 {
return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes")));
}
(parsed.encrypted_data_key, parsed.iv, parsed.algorithm)
} else {
// Test mode: parse metadata manually
let encrypted_key_b64 = metadata
.get("x-rustfs-encryption-key")
.ok_or_else(|| ApiError::from(StorageError::other("Missing encrypted key in metadata")))?;
let encrypted_data_key = BASE64_STANDARD
.decode(encrypted_key_b64)
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to decode encrypted key: {e}"))))?;
let iv_b64 = metadata
.get("x-rustfs-encryption-iv")
.ok_or_else(|| ApiError::from(StorageError::other("Missing IV in metadata")))?;
let iv = BASE64_STANDARD
.decode(iv_b64)
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to decode IV: {e}"))))?;
if iv.len() != 12 {
return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes")));
}
let algorithm = metadata
.get("x-rustfs-encryption-algorithm")
.cloned()
.unwrap_or_else(|| "AES256".to_string());
(encrypted_data_key, iv, algorithm)
};
let parsed = service
.headers_to_metadata(metadata)
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to parse encryption metadata: {e}"))))?;
// Extract KMS key ID from metadata (optional, used for provider context)
let kms_key_id = metadata
.get("x-amz-server-side-encryption-aws-kms-key-id")
.cloned()
.unwrap_or_else(|| "default".to_string());
if parsed.iv.len() != 12 {
return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes")));
}
let context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string());
let data_key = service
.decrypt_data_key(&parsed.encrypted_data_key, &context)
// Use factory pattern to get provider (test or production mode)
let provider = get_sse_dek_provider().await?;
let key_bytes = provider
.decrypt_sse_dek(&encrypted_data_key, &kms_key_id)
.await
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to decrypt data key: {e}"))))?;
let key_bytes = data_key.plaintext_key;
let mut base_nonce = [0u8; 12];
base_nonce.copy_from_slice(&parsed.iv[..12]);
base_nonce.copy_from_slice(&iv[..12]);
let nonce = if let Some(part_num) = part_number {
derive_part_nonce(base_nonce, part_num)
@@ -720,7 +770,7 @@ async fn apply_managed_decryption_material(
.get("x-rustfs-encryption-original-size")
.and_then(|s| s.parse::<i64>().ok());
let encryption_type = match parsed.algorithm.as_str() {
let encryption_type = match algorithm.as_str() {
"AES256" => EncryptionType::SseS3,
"aws:kms" => EncryptionType::SseKms,
_ => EncryptionType::SseS3,
@@ -774,18 +824,88 @@ pub struct SsecParams {
}
// ============================================================================
// Test purpose, custom SSE CMK implementation (SSE-S3 / SSE-KMS) Definitions
// SSE DEK Provider Abstraction (Factory Pattern)
// ============================================================================
// Define a trait to get SSE DEK by bucket and key
trait SseDekProvider {
/// Generate an SSE DEK
/// Trait for SSE data encryption key management
/// Abstracts the source of encryption keys (KMS, test provider, etc.)
#[async_trait]
pub trait SseDekProvider: Send + Sync {
/// Generate an SSE data encryption key
async fn generate_sse_dek(&self, bucket: &str, key: &str, kms_key_id: &str) -> Result<(DataKey, Vec<u8>), ApiError>;
/// Decrypt an SSE DEK (returns only plaintext key, nonce should be read from metadata)
/// Decrypt an SSE data encryption key (returns only plaintext key, nonce should be read from metadata)
async fn decrypt_sse_dek(&self, encrypted_dek: &[u8], kms_key_id: &str) -> Result<[u8; 32], ApiError>;
}
// ============================================================================
// Production KMS-backed DEK Provider
// ============================================================================
/// Production KMS-backed DEK provider
/// Wraps the global ObjectEncryptionService to provide SSE DEK operations
struct KmsSseDekProvider {
service: Arc<rustfs_kms::service::ObjectEncryptionService>,
}
impl KmsSseDekProvider {
/// Create a new KMS-backed provider
pub async fn new() -> Result<Self, ApiError> {
let service = get_global_encryption_service()
.await
.ok_or_else(|| ApiError::from(StorageError::other(
"KMS encryption service is not initialized"
)))?;
Ok(Self { service })
}
}
#[async_trait]
impl SseDekProvider for KmsSseDekProvider {
async fn generate_sse_dek(
&self,
bucket: &str,
key: &str,
kms_key_id: &str,
) -> Result<(DataKey, Vec<u8>), ApiError> {
let context = ObjectEncryptionContext::new(
bucket.to_string(),
key.to_string()
);
let kms_key_option = Some(kms_key_id.to_string());
let (data_key, encrypted_data_key) = self.service
.create_data_key(&kms_key_option, &context)
.await
.map_err(|e| ApiError::from(StorageError::other(
format!("Failed to create data key: {}", e)
)))?;
Ok((data_key, encrypted_data_key))
}
async fn decrypt_sse_dek(
&self,
encrypted_dek: &[u8],
_kms_key_id: &str,
) -> Result<[u8; 32], ApiError> {
// Create a minimal context for decryption
let context = ObjectEncryptionContext::new("".to_string(), "".to_string());
let data_key = self.service
.decrypt_data_key(encrypted_dek, &context)
.await
.map_err(|e| ApiError::from(StorageError::other(
format!("Failed to decrypt data key: {}", e)
)))?;
Ok(data_key.plaintext_key)
}
}
// ============================================================================
// Test/Simple DEK Provider
// ============================================================================
// Implement a simple SSE DEK provider for testing purposes
struct SimpleSseDekProvider {
cmk_ids: HashMap<String, [u8; 32]>,
@@ -794,6 +914,12 @@ struct SimpleSseDekProvider {
// __RUSTFS_SSE_SIMPLE_CMK_ID format: key-id1:base64_key1,key-id2:base64_key2,...
impl SimpleSseDekProvider {
/// Create a SimpleSseDekProvider with predefined keys (for testing)
#[cfg(test)]
pub fn new_with_keys(cmk_ids: HashMap<String, [u8; 32]>) -> Self {
Self { cmk_ids }
}
pub fn new() -> Self {
let cmk_id = std::env::var("__RUSTFS_SSE_SIMPLE_CMK_ID").unwrap_or_else(|_| "".to_string());
@@ -888,6 +1014,7 @@ impl SimpleSseDekProvider {
}
}
#[async_trait]
impl SseDekProvider for SimpleSseDekProvider {
async fn generate_sse_dek(&self, _bucket: &str, _key: &str, kms_key_id: &str) -> Result<(DataKey, Vec<u8>), ApiError> {
let cmk_value = self
@@ -933,6 +1060,65 @@ impl SseDekProvider for SimpleSseDekProvider {
}
}
// ============================================================================
// Factory Function for SSE DEK Provider
// ============================================================================
/// Global SSE DEK provider cache
static GLOBAL_SSE_DEK_PROVIDER: OnceLock<Arc<dyn SseDekProvider>> = OnceLock::new();
/// Get or initialize the global SSE DEK provider
///
/// Factory function that automatically selects the appropriate provider:
/// - If `__RUSTFS_SSE_SIMPLE_CMK_ID` environment variable exists: use SimpleSseDekProvider (test mode)
/// - Otherwise: use KmsSseDekProvider (production mode with real KMS)
///
/// # Returns
/// Arc to the global SSE DEK provider instance
///
/// # Example
/// ```rust,ignore
/// let provider = get_sse_dek_provider().await?;
/// let (data_key, encrypted_dek) = provider
/// .generate_sse_dek("bucket", "key", "kms-key-id")
/// .await?;
/// ```
pub async fn get_sse_dek_provider() -> Result<Arc<dyn SseDekProvider>, ApiError> {
// Check if already initialized
if let Some(provider) = GLOBAL_SSE_DEK_PROVIDER.get() {
return Ok(provider.clone());
}
// Determine provider based on environment variable
let provider: Arc<dyn SseDekProvider> = if std::env::var("__RUSTFS_SSE_SIMPLE_CMK_ID").is_ok() {
debug!("Using SimpleSseDekProvider (test mode) based on __RUSTFS_SSE_SIMPLE_CMK_ID");
Arc::new(SimpleSseDekProvider::new())
} else {
debug!("Using KmsSseDekProvider (production mode)");
Arc::new(KmsSseDekProvider::new().await?)
};
// Store in global cache
GLOBAL_SSE_DEK_PROVIDER
.set(provider.clone())
.map_err(|_| ApiError::from(StorageError::other(
"Failed to initialize global SSE DEK provider (already set)"
)))?;
Ok(provider)
}
/// Reset the global SSE DEK provider (for testing only)
///
/// Note: OnceLock doesn't support reset in stable Rust.
/// Tests should set environment variables before first call to `get_sse_dek_provider()`.
#[cfg(test)]
#[allow(dead_code)]
pub fn reset_sse_dek_provider() {
// OnceLock doesn't support reset - this is a documentation placeholder
// Consider using arc_swap::ArcSwap if runtime reset is needed
}
// ============================================================================
// Legacy Functions (SSE-S3 / SSE-KMS)
// ============================================================================
@@ -1428,7 +1614,9 @@ mod tests {
use tokio::io::AsyncReadExt;
// 1. Setup: Create SimpleSseDekProvider with test CMK
let provider = SimpleSseDekProvider::new();
let mut test_keys = HashMap::new();
test_keys.insert("test-key".to_string(), [42u8; 32]);
let provider = SimpleSseDekProvider::new_with_keys(test_keys);
// 2. Generate a data encryption key
let bucket = "test-bucket";
@@ -1494,8 +1682,10 @@ mod tests {
use std::io::Cursor;
use tokio::io::AsyncReadExt;
// Test with larger data to ensure streaming works correctly
let provider = SimpleSseDekProvider::new();
// 1. Setup: Create SimpleSseDekProvider with test CMK
let mut test_keys = HashMap::new();
test_keys.insert("test-key".to_string(), [42u8; 32]);
let provider = SimpleSseDekProvider::new_with_keys(test_keys);
let bucket = "test-bucket";
let key = "test-key-large";
@@ -1545,8 +1735,10 @@ mod tests {
use std::io::Cursor;
use tokio::io::AsyncReadExt;
// Verify that different nonces produce different ciphertext
let provider = SimpleSseDekProvider::new();
// 1. Setup: Create SimpleSseDekProvider with test CMK
let mut test_keys = HashMap::new();
test_keys.insert("test-key".to_string(), [42u8; 32]);
let provider = SimpleSseDekProvider::new_with_keys(test_keys);
let bucket = "test-bucket";
let key = "test-key";
@@ -1595,8 +1787,10 @@ mod tests {
use std::io::Cursor;
use tokio::io::AsyncReadExt;
// Test the full cycle: generate -> encrypt DEK -> decrypt DEK -> use for data encryption
let provider = SimpleSseDekProvider::new();
// 1. Setup: Create SimpleSseDekProvider with test CMK
let mut test_keys = HashMap::new();
test_keys.insert("test-key".to_string(), [42u8; 32]);
let provider = SimpleSseDekProvider::new_with_keys(test_keys);
let bucket = "test-bucket";
let key = "test-key";