Completely extract the SSE layer from the business logic.

This commit is contained in:
reatang
2026-01-11 18:15:33 +08:00
parent 09a90058ff
commit b4c436ffe0
10 changed files with 808 additions and 445 deletions

View File

@@ -24,12 +24,12 @@
//!
//! Run with: `cargo run --example demo1`
use std::fs;
use rustfs_kms::{
init_global_kms_service_manager, CreateKeyRequest, DescribeKeyRequest, EncryptionAlgorithm,
GenerateDataKeyRequest, KmsConfig, KeySpec, KeyUsage, ListKeysRequest,
CreateKeyRequest, DescribeKeyRequest, EncryptionAlgorithm, GenerateDataKeyRequest, KeySpec, KeyUsage, KmsConfig,
ListKeysRequest, init_global_kms_service_manager,
};
use std::collections::HashMap;
use std::fs;
use std::io::Cursor;
use tokio::io::AsyncReadExt;
@@ -121,7 +121,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" - Master Key (CMK): Used to encrypt/decrypt data keys");
println!(" - Data Key (DEK): Used to encrypt/decrypt actual data");
println!(" In production, you can skip this and use encrypt_object() directly!\n");
let data_key_request = GenerateDataKeyRequest {
key_id: master_key_id.clone(),
key_spec: KeySpec::Aes256,
@@ -137,7 +137,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" ✓ Data key generated (for demonstration):");
println!(" - Master Key ID: {}", data_key_response.key_id);
println!(" - Data Key (plaintext) length: {} bytes", data_key_response.plaintext_key.len());
println!(" - Encrypted Data Key (ciphertext blob) length: {} bytes", data_key_response.ciphertext_blob.len());
println!(
" - Encrypted Data Key (ciphertext blob) length: {} bytes",
data_key_response.ciphertext_blob.len()
);
println!(" - Note: This data key is NOT used in Step 9 - encrypt_object() generates its own!\n");
// Step 9: Encrypt some data using high-level API
@@ -149,7 +152,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" 3. Uses the data key to encrypt the actual data");
println!(" 4. Stores the encrypted data key (ciphertext blob) in metadata");
println!(" You only need to provide the master_key_id - everything else is handled!\n");
let plaintext = b"Hello, RustFS KMS! This is a test message for encryption.";
println!(" Plaintext: {}", String::from_utf8_lossy(plaintext));
@@ -169,8 +172,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" ✓ Data encrypted:");
println!(" - Encrypted data length: {} bytes", encryption_result.ciphertext.len());
println!(" - Algorithm: {}", encryption_result.metadata.algorithm);
println!(" - Master Key ID: {} (used to encrypt the data key)", encryption_result.metadata.key_id);
println!(" - Encrypted Data Key length: {} bytes (stored in metadata)", encryption_result.metadata.encrypted_data_key.len());
println!(
" - Master Key ID: {} (used to encrypt the data key)",
encryption_result.metadata.key_id
);
println!(
" - Encrypted Data Key length: {} bytes (stored in metadata)",
encryption_result.metadata.encrypted_data_key.len()
);
println!(" - Original size: {} bytes\n", encryption_result.metadata.original_size);
// Step 10: Decrypt the data using high-level API
@@ -180,7 +189,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" 2. Uses master key to decrypt the data key");
println!(" 3. Uses the decrypted data key to decrypt the actual data");
println!(" You only need to provide the encrypted data and metadata!\n");
let mut decrypted_reader = encryption_service
.decrypt_object(
"demo-bucket",
@@ -240,4 +249,3 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

View File

@@ -31,8 +31,8 @@
//! RUSTFS_KMS_VAULT_ADDRESS=http://127.0.0.1:8200 RUSTFS_KMS_VAULT_TOKEN=your-token cargo run --example demo2
use rustfs_kms::{
init_global_kms_service_manager, CreateKeyRequest, DescribeKeyRequest, EncryptionAlgorithm,
GenerateDataKeyRequest, KmsConfig, KmsError, KeySpec, KeyUsage, ListKeysRequest,
CreateKeyRequest, DescribeKeyRequest, EncryptionAlgorithm, GenerateDataKeyRequest, KeySpec, KeyUsage, KmsConfig, KmsError,
ListKeysRequest, init_global_kms_service_manager,
};
use std::collections::HashMap;
use std::io::Cursor;
@@ -53,18 +53,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Step 2: Get Vault configuration from environment or use defaults
println!("2. Configuring Vault backend...");
let vault_address = std::env::var("RUSTFS_KMS_VAULT_ADDRESS")
.unwrap_or_else(|_| "http://127.0.0.1:8200".to_string());
let vault_token = std::env::var("RUSTFS_KMS_VAULT_TOKEN")
.unwrap_or_else(|_| {
println!(" ⚠️ No RUSTFS_KMS_VAULT_TOKEN found, using default 'dev-token'");
println!(" For production, set RUSTFS_KMS_VAULT_TOKEN environment variable");
"dev-token".to_string()
});
let vault_address = std::env::var("RUSTFS_KMS_VAULT_ADDRESS").unwrap_or_else(|_| "http://127.0.0.1:8200".to_string());
let vault_token = std::env::var("RUSTFS_KMS_VAULT_TOKEN").unwrap_or_else(|_| {
println!(" ⚠️ No RUSTFS_KMS_VAULT_TOKEN found, using default 'dev-token'");
println!(" For production, set RUSTFS_KMS_VAULT_TOKEN environment variable");
"dev-token".to_string()
});
let vault_url = Url::parse(&vault_address).map_err(|e| format!("Invalid Vault address '{}': {}", vault_address, e))?;
let vault_url = Url::parse(&vault_address)
.map_err(|e| format!("Invalid Vault address '{}': {}", vault_address, e))?;
println!(" ✓ Vault address: {}", vault_address);
println!(" ✓ Using token authentication\n");
@@ -160,7 +157,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" - Master Key (CMK): Stored in Vault, used to encrypt/decrypt data keys");
println!(" - Data Key (DEK): Generated per object, encrypted by master key");
println!(" In production, you can skip this and use encrypt_object() directly!\n");
let data_key_request = GenerateDataKeyRequest {
key_id: master_key_id.clone(),
key_spec: KeySpec::Aes256,
@@ -176,7 +173,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" ✓ Data key generated (for demonstration):");
println!(" - Master Key ID: {}", data_key_response.key_id);
println!(" - Data Key (plaintext) length: {} bytes", data_key_response.plaintext_key.len());
println!(" - Encrypted Data Key (ciphertext blob) length: {} bytes", data_key_response.ciphertext_blob.len());
println!(
" - Encrypted Data Key (ciphertext blob) length: {} bytes",
data_key_response.ciphertext_blob.len()
);
println!(" - Note: This data key is NOT used in Step 9 - encrypt_object() generates its own!\n");
// Step 9: Encrypt some data using high-level API
@@ -188,7 +188,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" 3. Uses the data key to encrypt the actual data");
println!(" 4. Stores the encrypted data key (ciphertext blob) in metadata");
println!(" You only need to provide the master_key_id - everything else is handled!\n");
let plaintext = b"Hello, RustFS KMS with Vault! This is a test message for encryption.";
println!(" Plaintext: {}", String::from_utf8_lossy(plaintext));
@@ -208,8 +208,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" ✓ Data encrypted:");
println!(" - Encrypted data length: {} bytes", encryption_result.ciphertext.len());
println!(" - Algorithm: {}", encryption_result.metadata.algorithm);
println!(" - Master Key ID: {} (stored in Vault, used to encrypt the data key)", encryption_result.metadata.key_id);
println!(" - Encrypted Data Key length: {} bytes (stored in metadata)", encryption_result.metadata.encrypted_data_key.len());
println!(
" - Master Key ID: {} (stored in Vault, used to encrypt the data key)",
encryption_result.metadata.key_id
);
println!(
" - Encrypted Data Key length: {} bytes (stored in metadata)",
encryption_result.metadata.encrypted_data_key.len()
);
println!(" - Original size: {} bytes\n", encryption_result.metadata.original_size);
// Step 10: Decrypt the data using high-level API
@@ -219,7 +225,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!(" 2. Uses master key from Vault to decrypt the data key");
println!(" 3. Uses the decrypted data key to decrypt the actual data");
println!(" You only need to provide the encrypted data and metadata!\n");
let mut decrypted_reader = encryption_service
.decrypt_object(
"demo-bucket",
@@ -284,4 +290,3 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

View File

@@ -63,7 +63,6 @@ struct StoredMasterKey {
nonce: Vec<u8>,
}
impl LocalKmsClient {
/// Create a new local KMS client
pub async fn new(config: LocalConfig) -> Result<Self> {
@@ -202,7 +201,6 @@ impl LocalKmsClient {
Ok(())
}
/// Get the actual key material for a master key
async fn get_key_material(&self, key_id: &str) -> Result<Vec<u8>> {
let key_path = self.master_key_path(key_id);

View File

@@ -41,7 +41,6 @@ pub struct VaultKmsClient {
dek_crypto: AesDekCrypto,
}
/// Key data stored in Vault
#[derive(Debug, Clone, Serialize, Deserialize)]
struct VaultKeyData {
@@ -113,7 +112,6 @@ impl VaultKmsClient {
format!("{}/{}", self.key_path_prefix, key_id)
}
/// Encrypt key material using Vault's transit engine
async fn encrypt_key_material(&self, key_material: &[u8]) -> Result<String> {
// For simplicity, we'll base64 encode the key material
@@ -133,7 +131,7 @@ impl VaultKmsClient {
/// Get the actual key material for a master key
async fn get_key_material(&self, key_id: &str) -> Result<Vec<u8>> {
let mut key_data = self.get_key_data(key_id).await?;
// If encrypted_key_material is empty, generate and store it (fix for old keys)
if key_data.encrypted_key_material.is_empty() {
warn!("Key {} has empty encrypted_key_material, generating and storing new key material", key_id);
@@ -143,7 +141,7 @@ impl VaultKmsClient {
self.store_key_data(key_id, &key_data).await?;
return Ok(key_material);
}
let key_material = match self.decrypt_key_material(&key_data.encrypted_key_material).await {
Ok(km) => km,
Err(e) => {
@@ -155,19 +153,22 @@ impl VaultKmsClient {
return Ok(new_key_material);
}
};
// Validate key material length (should be 32 bytes for AES-256)
if key_material.len() != 32 {
// Try to fix: generate new key material if length is wrong
warn!("Key {} has invalid key material length ({} bytes), generating new key material",
key_id, key_material.len());
warn!(
"Key {} has invalid key material length ({} bytes), generating new key material",
key_id,
key_material.len()
);
let new_key_material = generate_key_material(&key_data.algorithm)?;
key_data.encrypted_key_material = self.encrypt_key_material(&new_key_material).await?;
// Store the updated key data back to Vault
self.store_key_data(key_id, &key_data).await?;
return Ok(new_key_material);
}
Ok(key_material)
}
@@ -225,9 +226,11 @@ impl VaultKmsClient {
encrypted_key_material: existing_key_data.encrypted_key_material.clone(), // Preserve the key material
};
debug!("VaultKeyData tags before storage: {:?}, encrypted_key_material length: {}",
key_data.tags,
key_data.encrypted_key_material.len());
debug!(
"VaultKeyData tags before storage: {:?}, encrypted_key_material length: {}",
key_data.tags,
key_data.encrypted_key_material.len()
);
self.store_key_data(key_id, &key_data).await
}

View File

@@ -108,8 +108,8 @@ impl DekCrypto for AesDekCrypto {
}
// Create cipher from key material
let key = Key::<Aes256Gcm>::try_from(key_material)
.map_err(|_| KmsError::cryptographic_error("key", "Invalid key length"))?;
let key =
Key::<Aes256Gcm>::try_from(key_material).map_err(|_| KmsError::cryptographic_error("key", "Invalid key length"))?;
let cipher = Aes256Gcm::new(&key);
// Generate random nonce (12 bytes for GCM)
@@ -145,8 +145,8 @@ impl DekCrypto for AesDekCrypto {
}
// Create cipher from key material
let key = Key::<Aes256Gcm>::try_from(key_material)
.map_err(|_| KmsError::cryptographic_error("key", "Invalid key length"))?;
let key =
Key::<Aes256Gcm>::try_from(key_material).map_err(|_| KmsError::cryptographic_error("key", "Invalid key length"))?;
let cipher = Aes256Gcm::new(&key);
// Convert nonce
@@ -287,8 +287,7 @@ mod tests {
assert!(!serialized.is_empty());
// Test deserialization
let deserialized: DataKeyEnvelope =
serde_json::from_slice(&serialized).expect("Deserialization should succeed");
let deserialized: DataKeyEnvelope = serde_json::from_slice(&serialized).expect("Deserialization should succeed");
assert_eq!(deserialized.key_id, envelope.key_id);
assert_eq!(deserialized.master_key_id, envelope.master_key_id);
assert_eq!(deserialized.encrypted_key, envelope.encrypted_key);
@@ -307,10 +306,8 @@ mod tests {
"created_at": "2024-01-01T00:00:00Z"
}"#;
let deserialized: DataKeyEnvelope =
serde_json::from_str(old_envelope_json).expect("Should deserialize old format");
let deserialized: DataKeyEnvelope = serde_json::from_str(old_envelope_json).expect("Should deserialize old format");
assert_eq!(deserialized.key_id, "test-key-id");
assert_eq!(deserialized.master_key_id, "master-key-id");
}
}

View File

@@ -62,8 +62,8 @@ mod cache;
pub mod config;
mod encryption;
mod error;
pub mod service;
pub mod manager;
pub mod service;
pub mod service_manager;
pub mod types;
@@ -74,9 +74,9 @@ pub use api_types::{
UntagKeyRequest, UntagKeyResponse, UpdateKeyDescriptionRequest, UpdateKeyDescriptionResponse,
};
pub use config::*;
pub use service::{DataKey, ObjectEncryptionService};
pub use error::{KmsError, Result};
pub use manager::KmsManager;
pub use service::{DataKey, ObjectEncryptionService};
pub use service_manager::{
KmsServiceManager, KmsServiceStatus, get_global_encryption_service, get_global_kms_service_manager,
init_global_kms_service_manager,
@@ -152,7 +152,10 @@ mod tests {
// Start first service
let temp_dir1 = TempDir::new().expect("Failed to create temp dir");
let config1 = KmsConfig::local(temp_dir1.path().to_path_buf());
manager.configure(config1.clone()).await.expect("Configuration should succeed");
manager
.configure(config1.clone())
.await
.expect("Configuration should succeed");
manager.start().await.expect("Start should succeed");
// Verify version 1
@@ -174,7 +177,7 @@ mod tests {
// Old service reference should still be valid (Arc keeps it alive)
// New requests should get version 2
let service2 = manager.get_encryption_service().await.expect("Service should be available");
// Verify they are different instances
assert!(!Arc::ptr_eq(&service1, &service2));
@@ -191,7 +194,7 @@ mod tests {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let base_path = temp_dir.path().to_path_buf();
// Initial configuration
let config1 = KmsConfig::local(base_path.clone());
manager.configure(config1).await.expect("Configuration should succeed");
@@ -214,7 +217,7 @@ mod tests {
for handle in handles {
results.push(handle.await);
}
// All should succeed (serialized by mutex)
for result in results {
assert!(result.expect("Task should complete").is_ok());

View File

@@ -16,12 +16,15 @@
use crate::backends::{KmsBackend, local::LocalKmsBackend};
use crate::config::{BackendConfig, KmsConfig};
use crate::service::ObjectEncryptionService;
use crate::error::{KmsError, Result};
use crate::manager::KmsManager;
use crate::service::ObjectEncryptionService;
use arc_swap::ArcSwap;
use std::sync::{Arc, OnceLock, atomic::{AtomicU64, Ordering}};
use tokio::sync::{RwLock, Mutex};
use std::sync::{
Arc, OnceLock,
atomic::{AtomicU64, Ordering},
};
use tokio::sync::{Mutex, RwLock};
use tracing::{error, info, warn};
/// KMS service status
@@ -155,7 +158,7 @@ impl KmsServiceManager {
}
/// Stop KMS service
///
///
/// Note: This stops accepting new operations, but existing operations using
/// the service will continue until they complete (due to Arc reference counting).
pub async fn stop(&self) -> Result<()> {
@@ -184,18 +187,18 @@ impl KmsServiceManager {
}
/// Reconfigure and restart KMS service with zero-downtime
///
///
/// This method implements versioned service switching:
/// 1. Creates a new service version without stopping the old one
/// 2. Atomically switches to the new version
/// 3. Old operations continue using the old service (via Arc reference counting)
/// 4. New operations automatically use the new service
///
///
/// This ensures zero downtime during reconfiguration, even for long-running
/// operations like encrypting large files.
pub async fn reconfigure(&self, new_config: KmsConfig) -> Result<()> {
let _guard = self.lifecycle_mutex.lock().await;
info!("Reconfiguring KMS service (zero-downtime)");
// Configure with new config
@@ -209,8 +212,7 @@ impl KmsServiceManager {
match self.create_service_version(&new_config).await {
Ok(new_service_version) => {
// Get old version for logging (lock-free read)
let old_version = self.current_service.load().as_ref().as_ref()
.and_then(|sv| Some(sv.version));
let old_version = self.current_service.load().as_ref().as_ref().and_then(|sv| Some(sv.version));
// Atomically switch to new service version (lock-free, instant CAS operation)
// This is a true atomic operation - no waiting for locks, instant switch
@@ -226,8 +228,7 @@ impl KmsServiceManager {
if let Some(old_ver) = old_version {
info!(
"KMS service reconfigured successfully: version {} -> {} (old service will be cleaned up when operations complete)",
old_ver,
new_service_version.version
old_ver, new_service_version.version
);
} else {
info!(
@@ -248,32 +249,29 @@ impl KmsServiceManager {
}
/// Get KMS manager (if running)
///
///
/// Returns the manager from the current service version.
/// Uses lock-free atomic load for optimal performance.
pub async fn get_manager(&self) -> Option<Arc<KmsManager>> {
self.current_service.load().as_ref().as_ref()
.map(|sv| sv.manager.clone())
self.current_service.load().as_ref().as_ref().map(|sv| sv.manager.clone())
}
/// Get encryption service (if running)
///
///
/// Returns the service from the current service version.
/// Uses lock-free atomic load - no blocking, instant access.
/// This ensures new operations always use the latest service version,
/// while existing operations continue using their Arc references.
pub async fn get_encryption_service(&self) -> Option<Arc<ObjectEncryptionService>> {
self.current_service.load().as_ref().as_ref()
.map(|sv| sv.service.clone())
self.current_service.load().as_ref().as_ref().map(|sv| sv.service.clone())
}
/// Get current service version number
///
///
/// Useful for monitoring and debugging.
/// Uses lock-free atomic load.
pub async fn get_service_version(&self) -> Option<u64> {
self.current_service.load().as_ref().as_ref()
.map(|sv| sv.version)
self.current_service.load().as_ref().as_ref().map(|sv| sv.version)
}
/// Health check for the KMS service
@@ -306,12 +304,12 @@ impl KmsServiceManager {
}
/// Create a new service version from configuration
///
///
/// This creates a new backend, manager, and service, and assigns it a new version number.
async fn create_service_version(&self, config: &KmsConfig) -> Result<ServiceVersion> {
// Increment version counter
let version = self.version_counter.fetch_add(1, Ordering::Relaxed) + 1;
info!("Creating KMS service version {} with backend: {:?}", version, config.backend);
// Create backend

View File

@@ -25,6 +25,12 @@ use crate::storage::concurrency::{
use crate::storage::entity;
use crate::storage::helper::OperationHelper;
use crate::storage::options::{filter_object_metadata, get_content_sha256};
use crate::storage::sse::{
InMemoryAsyncReader, ManagedEncryptionMaterial, SsecParams, apply_ssec_decryption, apply_ssec_encryption,
create_managed_encryption_material, decrypt_managed_encryption_key, decrypt_multipart_managed_stream, derive_part_nonce,
generate_ssec_nonce, is_managed_sse, store_ssec_metadata, strip_managed_encryption_metadata, validate_ssec_params,
verify_ssec_key_match,
};
use crate::storage::{
access::{ReqInfo, authorize_request},
options::{
@@ -32,7 +38,7 @@ use crate::storage::{
get_complete_multipart_upload_opts, get_opts, parse_copy_source_range, put_opts,
},
};
use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
// base64 imports moved to sse module
use bytes::Bytes;
use chrono::{DateTime, Utc};
use datafusion::arrow::{
@@ -89,13 +95,9 @@ use rustfs_ecstore::{
},
};
use rustfs_filemeta::REPLICATE_INCOMING_DELETE;
use rustfs_filemeta::{ObjectPartInfo, RestoreStatusOps};
use rustfs_filemeta::RestoreStatusOps;
use rustfs_filemeta::{ReplicationStatusType, ReplicationType, VersionPurgeStatusType};
use rustfs_kms::{
DataKey,
service_manager::get_global_encryption_service,
types::{EncryptionMetadata, ObjectEncryptionContext},
};
// KMS imports moved to sse module
use rustfs_notify::{EventArgsBuilder, notifier_global};
use rustfs_policy::policy::{
action::{Action, S3Action},
@@ -136,10 +138,8 @@ use std::{
sync::{Arc, LazyLock},
};
use time::{OffsetDateTime, format_description::well_known::Rfc3339};
use tokio::{
io::{AsyncRead, AsyncSeek},
sync::mpsc,
};
use tokio::sync::mpsc;
// AsyncRead and AsyncSeek moved to sse module
use tokio_stream::wrappers::ReceiverStream;
use tokio_tar::Archive;
use tokio_util::io::{ReaderStream, StreamReader};
@@ -265,118 +265,7 @@ pub struct FS {
// pub store: ECStore,
}
struct ManagedEncryptionMaterial {
data_key: DataKey,
headers: HashMap<String, String>,
kms_key_id: String,
}
async fn create_managed_encryption_material(
bucket: &str,
key: &str,
algorithm: &ServerSideEncryption,
kms_key_id: Option<String>,
original_size: i64,
) -> Result<ManagedEncryptionMaterial, ApiError> {
let Some(service) = get_global_encryption_service().await else {
return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized")));
};
if !is_managed_sse(algorithm) {
return Err(ApiError::from(StorageError::other(format!(
"Unsupported server-side encryption algorithm: {}",
algorithm.as_str()
))));
}
let algorithm_str = algorithm.as_str();
let mut context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string());
if original_size >= 0 {
context = context.with_size(original_size as u64);
}
let mut kms_key_candidate = kms_key_id;
if kms_key_candidate.is_none() {
kms_key_candidate = service.get_default_key_id().cloned();
}
let kms_key_to_use = kms_key_candidate
.clone()
.ok_or_else(|| ApiError::from(StorageError::other("No KMS key available for managed server-side encryption")))?;
let (data_key, encrypted_data_key) = service
.create_data_key(&kms_key_candidate, &context)
.await
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to create data key: {e}"))))?;
let metadata = EncryptionMetadata {
algorithm: algorithm_str.to_string(),
key_id: kms_key_to_use.clone(),
key_version: 1,
iv: data_key.nonce.to_vec(),
tag: None,
encryption_context: context.encryption_context.clone(),
encrypted_at: Utc::now(),
original_size: if original_size >= 0 { original_size as u64 } else { 0 },
encrypted_data_key,
};
let mut headers = service.metadata_to_headers(&metadata);
headers.insert("x-rustfs-encryption-original-size".to_string(), metadata.original_size.to_string());
Ok(ManagedEncryptionMaterial {
data_key,
headers,
kms_key_id: kms_key_to_use,
})
}
async fn decrypt_managed_encryption_key(
bucket: &str,
key: &str,
metadata: &HashMap<String, String>,
) -> Result<Option<([u8; 32], [u8; 12], Option<i64>)>, ApiError> {
if !metadata.contains_key("x-rustfs-encryption-key") {
return Ok(None);
}
let Some(service) = get_global_encryption_service().await else {
return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized")));
};
let parsed = service
.headers_to_metadata(metadata)
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to parse encryption metadata: {e}"))))?;
if parsed.iv.len() != 12 {
return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes")));
}
let context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string());
let data_key = service
.decrypt_data_key(&parsed.encrypted_data_key, &context)
.await
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to decrypt data key: {e}"))))?;
let key_bytes = data_key.plaintext_key;
let mut nonce = [0u8; 12];
nonce.copy_from_slice(&parsed.iv[..12]);
let original_size = metadata
.get("x-rustfs-encryption-original-size")
.and_then(|s| s.parse::<i64>().ok());
Ok(Some((key_bytes, nonce, original_size)))
}
fn derive_part_nonce(base: [u8; 12], part_number: usize) -> [u8; 12] {
let mut nonce = base;
let current = u32::from_be_bytes([nonce[8], nonce[9], nonce[10], nonce[11]]);
let incremented = current.wrapping_add(part_number as u32);
nonce[8..12].copy_from_slice(&incremented.to_be_bytes());
nonce
}
// SSE-related types and functions moved to crate::storage::sse module
#[derive(Debug, Default, serde::Deserialize)]
struct ListObjectUnorderedQuery {
@@ -384,98 +273,8 @@ struct ListObjectUnorderedQuery {
allow_unordered: Option<String>,
}
struct InMemoryAsyncReader {
cursor: std::io::Cursor<Vec<u8>>,
}
impl InMemoryAsyncReader {
fn new(data: Vec<u8>) -> Self {
Self {
cursor: std::io::Cursor::new(data),
}
}
}
impl AsyncRead for InMemoryAsyncReader {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let unfilled = buf.initialize_unfilled();
let bytes_read = std::io::Read::read(&mut self.cursor, unfilled)?;
buf.advance(bytes_read);
std::task::Poll::Ready(Ok(()))
}
}
impl AsyncSeek for InMemoryAsyncReader {
fn start_seek(mut self: std::pin::Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> {
// std::io::Cursor natively supports negative SeekCurrent offsets
// It will automatically handle validation and return an error if the final position would be negative
std::io::Seek::seek(&mut self.cursor, position)?;
Ok(())
}
fn poll_complete(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll<std::io::Result<u64>> {
std::task::Poll::Ready(Ok(self.cursor.position()))
}
}
async fn decrypt_multipart_managed_stream(
mut encrypted_stream: Box<dyn AsyncRead + Unpin + Send + Sync>,
parts: &[ObjectPartInfo],
key_bytes: [u8; 32],
base_nonce: [u8; 12],
) -> Result<(Box<dyn Reader>, i64), StorageError> {
let total_plain_capacity: usize = parts.iter().map(|part| part.actual_size.max(0) as usize).sum();
let mut plaintext = Vec::with_capacity(total_plain_capacity);
for part in parts {
if part.size == 0 {
continue;
}
let mut encrypted_part = vec![0u8; part.size];
tokio::io::AsyncReadExt::read_exact(&mut encrypted_stream, &mut encrypted_part)
.await
.map_err(|e| StorageError::other(format!("failed to read encrypted multipart segment {}: {}", part.number, e)))?;
let part_nonce = derive_part_nonce(base_nonce, part.number);
let cursor = std::io::Cursor::new(encrypted_part);
let mut decrypt_reader = DecryptReader::new(WarpReader::new(cursor), key_bytes, part_nonce);
tokio::io::AsyncReadExt::read_to_end(&mut decrypt_reader, &mut plaintext)
.await
.map_err(|e| StorageError::other(format!("failed to decrypt multipart segment {}: {}", part.number, e)))?;
}
let total_plain_size = plaintext.len() as i64;
let reader = Box::new(WarpReader::new(InMemoryAsyncReader::new(plaintext))) as Box<dyn Reader>;
Ok((reader, total_plain_size))
}
fn strip_managed_encryption_metadata(metadata: &mut HashMap<String, String>) {
const KEYS: [&str; 7] = [
"x-amz-server-side-encryption",
"x-amz-server-side-encryption-aws-kms-key-id",
"x-rustfs-encryption-iv",
"x-rustfs-encryption-tag",
"x-rustfs-encryption-key",
"x-rustfs-encryption-context",
"x-rustfs-encryption-original-size",
];
for key in KEYS.iter() {
metadata.remove(*key);
}
}
fn is_managed_sse(algorithm: &ServerSideEncryption) -> bool {
matches!(algorithm.as_str(), "AES256" | "aws:kms")
}
// InMemoryAsyncReader, decrypt_multipart_managed_stream, strip_managed_encryption_metadata,
// and is_managed_sse moved to crate::storage::sse module
/// Validate object key for control characters and log special characters
///
@@ -1284,39 +1083,19 @@ impl S3 for FS {
// Apply SSE-C encryption if customer-provided key is specified
if let (Some(sse_alg), Some(sse_key), Some(sse_md5)) = (&sse_customer_algorithm, &sse_customer_key, &sse_customer_key_md5)
&& sse_alg.as_str() == "AES256"
{
let key_bytes = BASE64_STANDARD.decode(sse_key.as_str()).map_err(|e| {
error!("Failed to decode SSE-C key: {}", e);
ApiError::from(StorageError::other("Invalid SSE-C key"))
})?;
let params = SsecParams {
algorithm: sse_alg.as_str().to_string(),
key: sse_key.clone(),
key_md5: sse_md5.clone(),
};
if key_bytes.len() != 32 {
return Err(ApiError::from(StorageError::other("SSE-C key must be 32 bytes")).into());
}
let validated = validate_ssec_params(&params)?;
let encrypted_reader = apply_ssec_encryption(reader, &validated, &bucket, &key);
reader = HashReader::new(encrypted_reader, -1, actual_size, None, None, false).map_err(ApiError::from)?;
let computed_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
if computed_md5 != sse_md5.as_str() {
return Err(ApiError::from(StorageError::other("SSE-C key MD5 mismatch")).into());
}
// Store original size before encryption
src_info
.user_defined
.insert("x-amz-server-side-encryption-customer-original-size".to_string(), actual_size.to_string());
// SAFETY: The length of `key_bytes` is checked to be 32 bytes above,
// so this conversion cannot fail.
let key_array: [u8; 32] = key_bytes.try_into().expect("key length already checked");
// Generate deterministic nonce from bucket-key
let nonce_source = format!("{bucket}-{key}");
let nonce_hash = md5::compute(nonce_source.as_bytes());
let nonce: [u8; 12] = nonce_hash.0[..12]
.try_into()
.expect("MD5 hash is always 16 bytes; taking first 12 bytes for nonce is safe");
let encrypt_reader = EncryptReader::new(reader, key_array, nonce);
reader = HashReader::new(Box::new(encrypt_reader), -1, actual_size, None, None, false).map_err(ApiError::from)?;
// Store SSE-C metadata for GET responses
store_ssec_metadata(&mut src_info.user_defined, &validated, actual_size);
}
src_info.put_object_reader = Some(PutObjReader::new(reader));
@@ -1327,19 +1106,6 @@ impl S3 for FS {
src_info.user_defined.insert(k, v);
}
// Store SSE-C metadata for GET responses
if let Some(ref sse_alg) = sse_customer_algorithm {
src_info.user_defined.insert(
"x-amz-server-side-encryption-customer-algorithm".to_string(),
sse_alg.as_str().to_string(),
);
}
if let Some(ref sse_md5) = sse_customer_key_md5 {
src_info
.user_defined
.insert("x-amz-server-side-encryption-customer-key-md5".to_string(), sse_md5.clone());
}
// check quota for copy operation
if let Some(metadata_sys) = rustfs_ecstore::bucket::metadata_sys::GLOBAL_BucketMetadataSys.get() {
let quota_checker = QuotaChecker::new(metadata_sys.clone());
@@ -2462,51 +2228,19 @@ impl S3 for FS {
// TODO: Implement proper multipart SSE-C encryption/decryption
} else {
// Verify that the provided key MD5 matches the stored MD5
if let Some(stored_md5) = stored_sse_key_md5 {
debug!("SSE-C MD5 comparison: provided='{}', stored='{}'", sse_key_md5_provided, stored_md5);
if sse_key_md5_provided != stored_md5 {
error!("SSE-C key MD5 mismatch: provided='{}', stored='{}'", sse_key_md5_provided, stored_md5);
return Err(
ApiError::from(StorageError::other("SSE-C key does not match object encryption key")).into()
);
}
} else {
return Err(ApiError::from(StorageError::other(
"Object encrypted with SSE-C but stored key MD5 not found",
))
.into());
}
verify_ssec_key_match(sse_key_md5_provided, stored_sse_key_md5)?;
// Decode the base64 key
let key_bytes = BASE64_STANDARD
.decode(sse_key)
.map_err(|e| ApiError::from(StorageError::other(format!("Invalid SSE-C key: {e}"))))?;
// Verify key length (should be 32 bytes for AES-256)
if key_bytes.len() != 32 {
return Err(ApiError::from(StorageError::other("SSE-C key must be 32 bytes")).into());
}
// Convert Vec<u8> to [u8; 32]
let mut key_array = [0u8; 32];
key_array.copy_from_slice(&key_bytes[..32]);
// Verify MD5 hash of the key matches what the client claims
let computed_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
if computed_md5 != *sse_key_md5_provided {
return Err(ApiError::from(StorageError::other("SSE-C key MD5 mismatch")).into());
}
// Generate the same deterministic nonce from object key
let mut nonce = [0u8; 12];
let nonce_source = format!("{bucket}-{key}");
let nonce_hash = md5::compute(nonce_source.as_bytes());
nonce.copy_from_slice(&nonce_hash.0[..12]);
// Validate and prepare SSE-C decryption parameters
let params = SsecParams {
algorithm: "AES256".to_string(),
key: sse_key.to_string(),
key_md5: sse_key_md5_provided.to_string(),
};
let validated = validate_ssec_params(&params)?;
// Apply decryption
// We need to wrap the stream in a Reader first since DecryptReader expects a Reader
let warp_reader = WarpReader::new(final_stream);
let decrypt_reader = DecryptReader::new(warp_reader, key_array, nonce);
let decrypt_reader = apply_ssec_decryption(warp_reader, &validated, &bucket, &key);
final_stream = Box::new(decrypt_reader);
}
} else {
@@ -3715,45 +3449,24 @@ impl S3 for FS {
}
// Apply SSE-C encryption if customer provided key
if let (Some(_), Some(sse_key), Some(sse_key_md5_provided)) =
if let (Some(sse_alg), Some(sse_key), Some(sse_key_md5)) =
(&sse_customer_algorithm, &sse_customer_key, &sse_customer_key_md5)
{
// Decode the base64 key
let key_bytes = BASE64_STANDARD
.decode(sse_key)
.map_err(|e| ApiError::from(StorageError::other(format!("Invalid SSE-C key: {e}"))))?;
let params = SsecParams {
algorithm: sse_alg.as_str().to_string(),
key: sse_key.clone(),
key_md5: sse_key_md5.clone(),
};
// Verify key length (should be 32 bytes for AES-256)
if key_bytes.len() != 32 {
return Err(ApiError::from(StorageError::other("SSE-C key must be 32 bytes")).into());
}
// Convert Vec<u8> to [u8; 32]
let mut key_array = [0u8; 32];
key_array.copy_from_slice(&key_bytes[..32]);
// Verify MD5 hash of the key matches what the client claims
let computed_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
if computed_md5 != *sse_key_md5_provided {
return Err(ApiError::from(StorageError::other("SSE-C key MD5 mismatch")).into());
}
let validated = validate_ssec_params(&params)?;
// Store original size for later retrieval during decryption
let original_size = if size >= 0 { size } else { actual_size };
metadata.insert(
"x-amz-server-side-encryption-customer-original-size".to_string(),
original_size.to_string(),
);
// Generate a deterministic nonce from object key for consistency
let mut nonce = [0u8; 12];
let nonce_source = format!("{bucket}-{key}");
let nonce_hash = md5::compute(nonce_source.as_bytes());
nonce.copy_from_slice(&nonce_hash.0[..12]);
store_ssec_metadata(&mut metadata, &validated, original_size);
// Apply encryption
let encrypt_reader = EncryptReader::new(reader, key_array, nonce);
reader = HashReader::new(Box::new(encrypt_reader), -1, actual_size, None, None, false).map_err(ApiError::from)?;
let encrypted_reader = apply_ssec_encryption(reader, &validated, &bucket, &key);
reader = HashReader::new(encrypted_reader, -1, actual_size, None, None, false).map_err(ApiError::from)?;
}
// Apply managed SSE (SSE-S3 or SSE-KMS) when requested
@@ -4052,9 +3765,9 @@ impl S3 for FS {
upload_id,
part_number,
content_length,
sse_customer_algorithm: _sse_customer_algorithm,
sse_customer_key: _sse_customer_key,
sse_customer_key_md5: _sse_customer_key_md5,
sse_customer_algorithm,
sse_customer_key,
sse_customer_key_md5,
// content_md5,
..
} = input;
@@ -4151,45 +3864,32 @@ impl S3 for FS {
let actual_size = size;
// TODO: Apply SSE-C encryption for upload_part if needed
// Temporarily commented out to debug multipart issues
/*
// Apply SSE-C encryption if customer provided key before any other processing
if let (Some(_), Some(sse_key), Some(sse_key_md5_provided)) =
(&_sse_customer_algorithm, &_sse_customer_key, &_sse_customer_key_md5) {
// Apply SSE-C encryption for upload_part if customer provided key
if let (Some(sse_alg), Some(sse_key), Some(sse_key_md5)) =
(&sse_customer_algorithm, &sse_customer_key, &sse_customer_key_md5)
{
let params = SsecParams {
algorithm: sse_alg.as_str().to_string(),
key: sse_key.clone(),
key_md5: sse_key_md5.clone(),
};
// Decode the base64 key
let key_bytes = BASE64_STANDARD.decode(sse_key)
.map_err(|e| ApiError::from(StorageError::other(format!("Invalid SSE-C key: {}", e))))?;
let validated = validate_ssec_params(&params)?;
// Verify key length (should be 32 bytes for AES-256)
if key_bytes.len() != 32 {
return Err(ApiError::from(StorageError::other("SSE-C key must be 32 bytes")).into());
}
// For multipart upload, derive a unique nonce for each part
// This ensures each part has a different nonce while maintaining determinism
let base_nonce = generate_ssec_nonce(&bucket, &key);
let part_nonce = derive_part_nonce(base_nonce, part_id);
// Convert Vec<u8> to [u8; 32]
let mut key_array = [0u8; 32];
key_array.copy_from_slice(&key_bytes[..32]);
// Apply encryption with part-specific nonce
let encrypted_reader = EncryptReader::new(reader, validated.key_bytes, part_nonce);
reader = Box::new(encrypted_reader);
// Verify MD5 hash of the key matches what the client claims
let computed_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
if computed_md5 != *sse_key_md5_provided {
return Err(ApiError::from(StorageError::other("SSE-C key MD5 mismatch")).into());
}
// Generate a deterministic nonce from object key for consistency
let mut nonce = [0u8; 12];
let nonce_source = format!("{}-{}", bucket, key);
let nonce_hash = md5::compute(nonce_source.as_bytes());
nonce.copy_from_slice(&nonce_hash.0[..12]);
// Apply encryption - this will change the size so we need to handle it
let encrypt_reader = EncryptReader::new(reader, key_array, nonce);
reader = Box::new(encrypt_reader);
// When encrypting, size becomes unknown since encryption adds authentication tags
size = -1;
debug!("Applied SSE-C encryption to part {} with derived nonce for {}/{}", part_id, bucket, key);
}
*/
let mut md5hex = if let Some(base64_md5) = input.content_md5 {
let md5 = base64_simd::STANDARD

View File

@@ -18,6 +18,7 @@ pub mod ecfs;
pub(crate) mod entity;
pub(crate) mod helper;
pub mod options;
pub mod sse;
pub mod tonic_service;
#[cfg(test)]

650
rustfs/src/storage/sse.rs Normal file
View File

@@ -0,0 +1,650 @@
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Server-Side Encryption (SSE) utilities
//!
//! This module provides reusable components for handling S3 Server-Side Encryption:
//! - SSE-S3 (AES256): Server-managed encryption with S3-managed keys
//! - SSE-KMS (aws:kms): Server-managed encryption with KMS-managed keys
//! - SSE-C (AES256): Customer-provided encryption keys
//!
//! ## Architecture
//!
//! ### Managed SSE (SSE-S3 / SSE-KMS)
//! - Keys are managed by the server-side KMS service
//! - Data keys are generated and encrypted by KMS
//! - Encryption metadata is stored in object metadata
//!
//! ### Customer-Provided Keys (SSE-C)
//! - Keys are provided by the client on every request
//! - Server validates key using MD5 hash
//! - Keys are NEVER stored on the server
//!
//! ## Usage Example
//!
//! ```rust,ignore
//! // Apply managed SSE encryption
//! if let Some(material) = apply_managed_sse(
//! &bucket, &key, &sse_algorithm, kms_key_id, actual_size
//! ).await? {
//! reader = material.wrap_encrypt_reader(reader)?;
//! metadata.extend(material.headers);
//! }
//!
//! // Apply SSE-C encryption
//! if let Some(params) = sse_customer_params {
//! let validated = validate_ssec_params(&params)?;
//! reader = apply_ssec_encryption(reader, &validated, &bucket, &key)?;
//! }
//! ```
use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
use chrono::Utc;
use rustfs_ecstore::error::StorageError;
use rustfs_filemeta::ObjectPartInfo;
use rustfs_kms::{
DataKey,
service_manager::get_global_encryption_service,
types::{EncryptionMetadata, ObjectEncryptionContext},
};
use rustfs_rio::{DecryptReader, EncryptReader, Reader, WarpReader};
use s3s::dto::ServerSideEncryption;
use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncSeek};
use tracing::error;
use crate::error::ApiError;
// ============================================================================
// Public Types
// ============================================================================
/// Material for managed server-side encryption (SSE-S3/SSE-KMS)
#[derive(Debug, Clone)]
pub struct ManagedEncryptionMaterial {
/// Data encryption key
pub data_key: DataKey,
/// Metadata headers to store with the object
pub headers: HashMap<String, String>,
/// KMS key ID used for encryption
pub kms_key_id: String,
}
/// Validated SSE-C parameters
#[derive(Debug, Clone)]
pub struct ValidatedSsecParams {
/// Encryption algorithm (always "AES256" for SSE-C)
pub algorithm: String,
/// Decoded encryption key bytes (32 bytes for AES-256)
pub key_bytes: [u8; 32],
/// Base64-encoded MD5 of the key
pub key_md5: String,
}
/// SSE-C parameters from client request
#[derive(Debug, Clone)]
pub struct SsecParams {
/// Encryption algorithm
pub algorithm: String,
/// Base64-encoded encryption key
pub key: String,
/// Base64-encoded MD5 of the key
pub key_md5: String,
}
// ============================================================================
// Managed SSE Functions (SSE-S3 / SSE-KMS)
// ============================================================================
/// Check if the algorithm is a managed SSE type (SSE-S3 or SSE-KMS)
#[inline]
pub fn is_managed_sse(algorithm: &ServerSideEncryption) -> bool {
matches!(algorithm.as_str(), "AES256" | "aws:kms")
}
/// Create managed encryption material for SSE-S3 or SSE-KMS
///
/// This function:
/// 1. Validates the encryption algorithm
/// 2. Creates an encryption context
/// 3. Generates a data key via KMS
/// 4. Prepares metadata headers for storage
///
/// # Arguments
/// * `bucket` - Bucket name
/// * `key` - Object key
/// * `algorithm` - Encryption algorithm (AES256 or aws:kms)
/// * `kms_key_id` - Optional KMS key ID (uses default if None)
/// * `original_size` - Original object size before encryption
///
/// # Returns
/// `ManagedEncryptionMaterial` containing data key, headers, and key ID
pub async fn create_managed_encryption_material(
bucket: &str,
key: &str,
algorithm: &ServerSideEncryption,
kms_key_id: Option<String>,
original_size: i64,
) -> Result<ManagedEncryptionMaterial, ApiError> {
let Some(service) = get_global_encryption_service().await else {
return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized")));
};
if !is_managed_sse(algorithm) {
return Err(ApiError::from(StorageError::other(format!(
"Unsupported server-side encryption algorithm: {}",
algorithm.as_str()
))));
}
let algorithm_str = algorithm.as_str();
let mut context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string());
if original_size >= 0 {
context = context.with_size(original_size as u64);
}
let mut kms_key_candidate = kms_key_id;
if kms_key_candidate.is_none() {
kms_key_candidate = service.get_default_key_id().cloned();
}
let kms_key_to_use = kms_key_candidate
.clone()
.ok_or_else(|| ApiError::from(StorageError::other("No KMS key available for managed server-side encryption")))?;
let (data_key, encrypted_data_key) = service
.create_data_key(&kms_key_candidate, &context)
.await
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to create data key: {e}"))))?;
let metadata = EncryptionMetadata {
algorithm: algorithm_str.to_string(),
key_id: kms_key_to_use.clone(),
key_version: 1,
iv: data_key.nonce.to_vec(),
tag: None,
encryption_context: context.encryption_context.clone(),
encrypted_at: Utc::now(),
original_size: if original_size >= 0 { original_size as u64 } else { 0 },
encrypted_data_key,
};
let mut headers = service.metadata_to_headers(&metadata);
headers.insert("x-rustfs-encryption-original-size".to_string(), metadata.original_size.to_string());
Ok(ManagedEncryptionMaterial {
data_key,
headers,
kms_key_id: kms_key_to_use,
})
}
/// Decrypt managed encryption key from object metadata
///
/// This function:
/// 1. Checks if object has managed encryption metadata
/// 2. Parses encryption metadata from headers
/// 3. Decrypts the data key using KMS
///
/// # Arguments
/// * `bucket` - Bucket name
/// * `key` - Object key
/// * `metadata` - Object metadata containing encryption headers
///
/// # Returns
/// `Some((key_bytes, nonce, original_size))` if object is encrypted, `None` otherwise
pub async fn decrypt_managed_encryption_key(
bucket: &str,
key: &str,
metadata: &HashMap<String, String>,
) -> Result<Option<([u8; 32], [u8; 12], Option<i64>)>, ApiError> {
if !metadata.contains_key("x-rustfs-encryption-key") {
return Ok(None);
}
let Some(service) = get_global_encryption_service().await else {
return Err(ApiError::from(StorageError::other("KMS encryption service is not initialized")));
};
let parsed = service
.headers_to_metadata(metadata)
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to parse encryption metadata: {e}"))))?;
if parsed.iv.len() != 12 {
return Err(ApiError::from(StorageError::other("Invalid encryption nonce length; expected 12 bytes")));
}
let context = ObjectEncryptionContext::new(bucket.to_string(), key.to_string());
let data_key = service
.decrypt_data_key(&parsed.encrypted_data_key, &context)
.await
.map_err(|e| ApiError::from(StorageError::other(format!("Failed to decrypt data key: {e}"))))?;
let key_bytes = data_key.plaintext_key;
let mut nonce = [0u8; 12];
nonce.copy_from_slice(&parsed.iv[..12]);
let original_size = metadata
.get("x-rustfs-encryption-original-size")
.and_then(|s| s.parse::<i64>().ok());
Ok(Some((key_bytes, nonce, original_size)))
}
/// Strip managed encryption metadata from object metadata
///
/// Removes all managed SSE-related headers before returning object metadata to client.
/// This is necessary because encryption is transparent to S3 clients.
pub fn strip_managed_encryption_metadata(metadata: &mut HashMap<String, String>) {
const KEYS: [&str; 7] = [
"x-amz-server-side-encryption",
"x-amz-server-side-encryption-aws-kms-key-id",
"x-rustfs-encryption-iv",
"x-rustfs-encryption-tag",
"x-rustfs-encryption-key",
"x-rustfs-encryption-context",
"x-rustfs-encryption-original-size",
];
for key in KEYS.iter() {
metadata.remove(*key);
}
}
// ============================================================================
// Multipart Encryption Support
// ============================================================================
/// Derive a unique nonce for each part in a multipart upload
///
/// Uses the base nonce and increments the counter portion by part number.
/// This ensures each part has a unique nonce while maintaining determinism.
pub fn derive_part_nonce(base: [u8; 12], part_number: usize) -> [u8; 12] {
let mut nonce = base;
let current = u32::from_be_bytes([nonce[8], nonce[9], nonce[10], nonce[11]]);
let incremented = current.wrapping_add(part_number as u32);
nonce[8..12].copy_from_slice(&incremented.to_be_bytes());
nonce
}
/// In-memory async reader for decrypted multipart data
pub(crate) struct InMemoryAsyncReader {
cursor: std::io::Cursor<Vec<u8>>,
}
impl InMemoryAsyncReader {
pub(crate) fn new(data: Vec<u8>) -> Self {
Self {
cursor: std::io::Cursor::new(data),
}
}
}
impl AsyncRead for InMemoryAsyncReader {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let unfilled = buf.initialize_unfilled();
let bytes_read = std::io::Read::read(&mut self.cursor, unfilled)?;
buf.advance(bytes_read);
std::task::Poll::Ready(Ok(()))
}
}
impl AsyncSeek for InMemoryAsyncReader {
fn start_seek(mut self: std::pin::Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> {
std::io::Seek::seek(&mut self.cursor, position)?;
Ok(())
}
fn poll_complete(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll<std::io::Result<u64>> {
std::task::Poll::Ready(Ok(self.cursor.position()))
}
}
/// Decrypt a multipart upload stream with managed SSE encryption
///
/// This function:
/// 1. Reads each encrypted part from the stream
/// 2. Derives a unique nonce for each part
/// 3. Decrypts each part individually
/// 4. Concatenates all plaintext parts into a single buffer
///
/// # Arguments
/// * `encrypted_stream` - Stream containing encrypted multipart data
/// * `parts` - Part info containing sizes and part numbers
/// * `key_bytes` - Decryption key
/// * `base_nonce` - Base nonce (unique nonce derived per part)
///
/// # Returns
/// Tuple of (decrypted_reader, total_plaintext_size)
pub async fn decrypt_multipart_managed_stream(
mut encrypted_stream: Box<dyn AsyncRead + Unpin + Send + Sync>,
parts: &[ObjectPartInfo],
key_bytes: [u8; 32],
base_nonce: [u8; 12],
) -> Result<(Box<dyn Reader>, i64), StorageError> {
let total_plain_capacity: usize = parts.iter().map(|part| part.actual_size.max(0) as usize).sum();
let mut plaintext = Vec::with_capacity(total_plain_capacity);
for part in parts {
if part.size == 0 {
continue;
}
let mut encrypted_part = vec![0u8; part.size];
tokio::io::AsyncReadExt::read_exact(&mut encrypted_stream, &mut encrypted_part)
.await
.map_err(|e| StorageError::other(format!("failed to read encrypted multipart segment {}: {}", part.number, e)))?;
let part_nonce = derive_part_nonce(base_nonce, part.number);
let cursor = std::io::Cursor::new(encrypted_part);
let mut decrypt_reader = DecryptReader::new(WarpReader::new(cursor), key_bytes, part_nonce);
tokio::io::AsyncReadExt::read_to_end(&mut decrypt_reader, &mut plaintext)
.await
.map_err(|e| StorageError::other(format!("failed to decrypt multipart segment {}: {}", part.number, e)))?;
}
let total_plain_size = plaintext.len() as i64;
let reader = Box::new(WarpReader::new(InMemoryAsyncReader::new(plaintext))) as Box<dyn Reader>;
Ok((reader, total_plain_size))
}
// ============================================================================
// Customer-Provided Key (SSE-C) Functions
// ============================================================================
/// Validate SSE-C parameters from client request
///
/// This function:
/// 1. Validates the algorithm is AES256
/// 2. Decodes the Base64-encoded key
/// 3. Validates key length is 32 bytes
/// 4. Verifies MD5 hash matches
///
/// # Arguments
/// * `params` - SSE-C parameters from client
///
/// # Returns
/// `ValidatedSsecParams` with decoded key bytes
pub fn validate_ssec_params(params: &SsecParams) -> Result<ValidatedSsecParams, ApiError> {
// Validate algorithm
if params.algorithm != "AES256" {
return Err(ApiError::from(StorageError::other(format!(
"Unsupported SSE-C algorithm: {}. Only AES256 is supported",
params.algorithm
))));
}
// Decode Base64 key
let key_bytes = BASE64_STANDARD.decode(&params.key).map_err(|e| {
error!("Failed to decode SSE-C key: {}", e);
ApiError::from(StorageError::other("Invalid SSE-C key: not valid Base64"))
})?;
// Validate key length (must be 32 bytes for AES-256)
if key_bytes.len() != 32 {
return Err(ApiError::from(StorageError::other(format!(
"SSE-C key must be 32 bytes (256 bits), got {} bytes",
key_bytes.len()
))));
}
// Verify MD5 hash
let computed_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
if computed_md5 != params.key_md5 {
error!("SSE-C key MD5 mismatch: expected '{}', got '{}'", params.key_md5, computed_md5);
return Err(ApiError::from(StorageError::other("SSE-C key MD5 mismatch")));
}
// SAFETY: We validated the length is exactly 32 bytes above
let key_array: [u8; 32] = key_bytes.try_into().expect("key length already validated to be 32 bytes");
Ok(ValidatedSsecParams {
algorithm: params.algorithm.clone(),
key_bytes: key_array,
key_md5: params.key_md5.clone(),
})
}
/// Generate deterministic nonce for SSE-C encryption
///
/// The nonce is derived from the bucket and key to ensure:
/// 1. Same object always gets the same nonce (required for SSE-C)
/// 2. Different objects get different nonces
pub fn generate_ssec_nonce(bucket: &str, key: &str) -> [u8; 12] {
let nonce_source = format!("{bucket}-{key}");
let nonce_hash = md5::compute(nonce_source.as_bytes());
let mut nonce = [0u8; 12];
nonce.copy_from_slice(&nonce_hash.0[..12]);
nonce
}
/// Apply SSE-C encryption to a reader
///
/// # Arguments
/// * `reader` - Input reader to encrypt
/// * `validated` - Validated SSE-C parameters
/// * `bucket` - Bucket name (for nonce generation)
/// * `key` - Object key (for nonce generation)
///
/// # Returns
/// Encrypted reader wrapped in Box
pub fn apply_ssec_encryption<R>(reader: R, validated: &ValidatedSsecParams, bucket: &str, key: &str) -> Box<EncryptReader<R>>
where
R: Reader + 'static,
{
let nonce = generate_ssec_nonce(bucket, key);
Box::new(EncryptReader::new(reader, validated.key_bytes, nonce))
}
/// Apply SSE-C decryption to a reader
///
/// # Arguments
/// * `reader` - Encrypted reader to decrypt
/// * `validated` - Validated SSE-C parameters
/// * `bucket` - Bucket name (for nonce generation)
/// * `key` - Object key (for nonce generation)
///
/// # Returns
/// Decrypted reader wrapped in Box
pub fn apply_ssec_decryption<R>(reader: R, validated: &ValidatedSsecParams, bucket: &str, key: &str) -> Box<DecryptReader<R>>
where
R: Reader + 'static,
{
let nonce = generate_ssec_nonce(bucket, key);
Box::new(DecryptReader::new(reader, validated.key_bytes, nonce))
}
/// Store SSE-C metadata in object metadata
///
/// Stores the algorithm and key MD5 for later validation during GetObject.
/// Note: The encryption key itself is NEVER stored.
pub fn store_ssec_metadata(metadata: &mut HashMap<String, String>, validated: &ValidatedSsecParams, original_size: i64) {
metadata.insert("x-amz-server-side-encryption-customer-algorithm".to_string(), validated.algorithm.clone());
metadata.insert("x-amz-server-side-encryption-customer-key-md5".to_string(), validated.key_md5.clone());
metadata.insert(
"x-amz-server-side-encryption-customer-original-size".to_string(),
original_size.to_string(),
);
}
/// Verify SSE-C key matches the stored metadata
///
/// Used during GetObject to ensure the client provided the correct key.
pub fn verify_ssec_key_match(provided_md5: &str, stored_md5: Option<&String>) -> Result<(), ApiError> {
let Some(stored) = stored_md5 else {
return Err(ApiError::from(StorageError::other(
"Object encrypted with SSE-C but stored key MD5 not found",
)));
};
if provided_md5 != stored {
error!("SSE-C key MD5 mismatch: provided='{}', stored='{}'", provided_md5, stored);
return Err(ApiError::from(StorageError::other("SSE-C key does not match object encryption key")));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_managed_sse() {
assert!(is_managed_sse(&ServerSideEncryption::from_static("AES256")));
assert!(is_managed_sse(&ServerSideEncryption::from_static("aws:kms")));
assert!(!is_managed_sse(&ServerSideEncryption::from_static("invalid")));
}
#[test]
fn test_derive_part_nonce() {
let base_nonce: [u8; 12] = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 1];
let part1_nonce = derive_part_nonce(base_nonce, 1);
let part2_nonce = derive_part_nonce(base_nonce, 2);
// First 8 bytes should be the same
assert_eq!(&part1_nonce[..8], &base_nonce[..8]);
assert_eq!(&part2_nonce[..8], &base_nonce[..8]);
// Last 4 bytes should be different (counter)
assert_ne!(&part1_nonce[8..], &part2_nonce[8..]);
}
#[test]
fn test_generate_ssec_nonce() {
let nonce1 = generate_ssec_nonce("bucket1", "key1");
let nonce2 = generate_ssec_nonce("bucket1", "key1");
let nonce3 = generate_ssec_nonce("bucket1", "key2");
// Same bucket/key should generate same nonce
assert_eq!(nonce1, nonce2);
// Different key should generate different nonce
assert_ne!(nonce1, nonce3);
// Nonce should be 12 bytes
assert_eq!(nonce1.len(), 12);
}
#[test]
fn test_validate_ssec_params_success() {
// Generate a valid 32-byte key
let key_bytes = [42u8; 32];
let key_b64 = BASE64_STANDARD.encode(key_bytes);
let key_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
let params = SsecParams {
algorithm: "AES256".to_string(),
key: key_b64,
key_md5,
};
let result = validate_ssec_params(&params);
assert!(result.is_ok());
let validated = result.unwrap();
assert_eq!(validated.algorithm, "AES256");
assert_eq!(validated.key_bytes, key_bytes);
}
#[test]
fn test_validate_ssec_params_wrong_algorithm() {
let key_bytes = [42u8; 32];
let key_b64 = BASE64_STANDARD.encode(key_bytes);
let key_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
let params = SsecParams {
algorithm: "AES128".to_string(),
key: key_b64,
key_md5,
};
let result = validate_ssec_params(&params);
assert!(result.is_err());
}
#[test]
fn test_validate_ssec_params_wrong_key_length() {
let key_bytes = [42u8; 16]; // Wrong length
let key_b64 = BASE64_STANDARD.encode(key_bytes);
let key_md5 = BASE64_STANDARD.encode(md5::compute(&key_bytes).0);
let params = SsecParams {
algorithm: "AES256".to_string(),
key: key_b64,
key_md5,
};
let result = validate_ssec_params(&params);
assert!(result.is_err());
}
#[test]
fn test_validate_ssec_params_wrong_md5() {
let key_bytes = [42u8; 32];
let key_b64 = BASE64_STANDARD.encode(key_bytes);
let wrong_md5 = "wrong_md5_hash_here==";
let params = SsecParams {
algorithm: "AES256".to_string(),
key: key_b64,
key_md5: wrong_md5.to_string(),
};
let result = validate_ssec_params(&params);
assert!(result.is_err());
}
#[test]
fn test_strip_managed_encryption_metadata() {
let mut metadata = HashMap::new();
metadata.insert("x-amz-server-side-encryption".to_string(), "AES256".to_string());
metadata.insert("x-rustfs-encryption-key".to_string(), "encrypted_key".to_string());
metadata.insert("content-type".to_string(), "text/plain".to_string());
strip_managed_encryption_metadata(&mut metadata);
assert!(!metadata.contains_key("x-amz-server-side-encryption"));
assert!(!metadata.contains_key("x-rustfs-encryption-key"));
assert!(metadata.contains_key("content-type")); // Should not be removed
}
#[test]
fn test_verify_ssec_key_match_success() {
let stored_md5 = "abc123".to_string();
let result = verify_ssec_key_match("abc123", Some(&stored_md5));
assert!(result.is_ok());
}
#[test]
fn test_verify_ssec_key_match_mismatch() {
let stored_md5 = "abc123".to_string();
let result = verify_ssec_key_match("xyz789", Some(&stored_md5));
assert!(result.is_err());
}
#[test]
fn test_verify_ssec_key_match_no_stored() {
let result = verify_ssec_key_match("abc123", None);
assert!(result.is_err());
}
}