diff --git a/crates/ecstore/src/bucket/metadata.rs b/crates/ecstore/src/bucket/metadata.rs index d4a9ddac..78ceaf0b 100644 --- a/crates/ecstore/src/bucket/metadata.rs +++ b/crates/ecstore/src/bucket/metadata.rs @@ -25,8 +25,8 @@ use byteorder::{BigEndian, ByteOrder, LittleEndian}; use rmp_serde::Serializer as rmpSerializer; use rustfs_policy::policy::BucketPolicy; use s3s::dto::{ - BucketLifecycleConfiguration, NotificationConfiguration, ObjectLockConfiguration, ReplicationConfiguration, - ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration, + BucketLifecycleConfiguration, CORSConfiguration, NotificationConfiguration, ObjectLockConfiguration, + ReplicationConfiguration, ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration, }; use serde::Serializer; use serde::{Deserialize, Serialize}; @@ -49,6 +49,7 @@ pub const OBJECT_LOCK_CONFIG: &str = "object-lock.xml"; pub const BUCKET_VERSIONING_CONFIG: &str = "versioning.xml"; pub const BUCKET_REPLICATION_CONFIG: &str = "replication.xml"; pub const BUCKET_TARGETS_FILE: &str = "bucket-targets.json"; +pub const BUCKET_CORS_CONFIG: &str = "cors.xml"; #[derive(Debug, Deserialize, Serialize, Clone)] #[serde(rename_all = "PascalCase", default)] @@ -67,6 +68,7 @@ pub struct BucketMetadata { pub replication_config_xml: Vec, pub bucket_targets_config_json: Vec, pub bucket_targets_config_meta_json: Vec, + pub cors_config_xml: Vec, pub policy_config_updated_at: OffsetDateTime, pub object_lock_config_updated_at: OffsetDateTime, @@ -79,6 +81,7 @@ pub struct BucketMetadata { pub notification_config_updated_at: OffsetDateTime, pub bucket_targets_config_updated_at: OffsetDateTime, pub bucket_targets_config_meta_updated_at: OffsetDateTime, + pub cors_config_updated_at: OffsetDateTime, #[serde(skip)] pub new_field_updated_at: OffsetDateTime, @@ -105,6 +108,8 @@ pub struct BucketMetadata { pub bucket_target_config: Option, #[serde(skip)] pub bucket_target_config_meta: Option>, + #[serde(skip)] + pub cors_config: Option, } impl Default for BucketMetadata { @@ -124,6 +129,7 @@ impl Default for BucketMetadata { replication_config_xml: Default::default(), bucket_targets_config_json: Default::default(), bucket_targets_config_meta_json: Default::default(), + cors_config_xml: Default::default(), policy_config_updated_at: OffsetDateTime::UNIX_EPOCH, object_lock_config_updated_at: OffsetDateTime::UNIX_EPOCH, encryption_config_updated_at: OffsetDateTime::UNIX_EPOCH, @@ -135,6 +141,7 @@ impl Default for BucketMetadata { notification_config_updated_at: OffsetDateTime::UNIX_EPOCH, bucket_targets_config_updated_at: OffsetDateTime::UNIX_EPOCH, bucket_targets_config_meta_updated_at: OffsetDateTime::UNIX_EPOCH, + cors_config_updated_at: OffsetDateTime::UNIX_EPOCH, new_field_updated_at: OffsetDateTime::UNIX_EPOCH, policy_config: Default::default(), notification_config: Default::default(), @@ -147,6 +154,7 @@ impl Default for BucketMetadata { replication_config: Default::default(), bucket_target_config: Default::default(), bucket_target_config_meta: Default::default(), + cors_config: Default::default(), } } } @@ -295,6 +303,10 @@ impl BucketMetadata { self.bucket_targets_config_json = data.clone(); self.bucket_targets_config_updated_at = updated; } + BUCKET_CORS_CONFIG => { + self.cors_config_xml = data; + self.cors_config_updated_at = updated; + } _ => return Err(Error::other(format!("config file not found : {config_file}"))), } @@ -365,6 +377,9 @@ impl BucketMetadata { } else { self.bucket_target_config = Some(BucketTargets::default()) } + if !self.cors_config_xml.is_empty() { + self.cors_config = Some(deserialize::(&self.cors_config_xml)?); + } Ok(()) } diff --git a/crates/ecstore/src/bucket/metadata_sys.rs b/crates/ecstore/src/bucket/metadata_sys.rs index dad17b97..9b14857e 100644 --- a/crates/ecstore/src/bucket/metadata_sys.rs +++ b/crates/ecstore/src/bucket/metadata_sys.rs @@ -28,8 +28,8 @@ use rustfs_common::heal_channel::HealOpts; use rustfs_policy::policy::BucketPolicy; use s3s::dto::ReplicationConfiguration; use s3s::dto::{ - BucketLifecycleConfiguration, NotificationConfiguration, ObjectLockConfiguration, ServerSideEncryptionConfiguration, Tagging, - VersioningConfiguration, + BucketLifecycleConfiguration, CORSConfiguration, NotificationConfiguration, ObjectLockConfiguration, + ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration, }; use std::collections::HashSet; use std::sync::OnceLock; @@ -110,6 +110,13 @@ pub async fn get_bucket_targets_config(bucket: &str) -> Result { bucket_meta_sys.get_bucket_targets_config(bucket).await } +pub async fn get_cors_config(bucket: &str) -> Result<(CORSConfiguration, OffsetDateTime)> { + let bucket_meta_sys_lock = get_bucket_metadata_sys()?; + let bucket_meta_sys = bucket_meta_sys_lock.read().await; + + bucket_meta_sys.get_cors_config(bucket).await +} + pub async fn get_tagging_config(bucket: &str) -> Result<(Tagging, OffsetDateTime)> { let bucket_meta_sys_lock = get_bucket_metadata_sys()?; let bucket_meta_sys = bucket_meta_sys_lock.read().await; @@ -500,6 +507,16 @@ impl BucketMetadataSys { } } + pub async fn get_cors_config(&self, bucket: &str) -> Result<(CORSConfiguration, OffsetDateTime)> { + let (bm, _) = self.get_config(bucket).await?; + + if let Some(config) = &bm.cors_config { + Ok((config.clone(), bm.cors_config_updated_at)) + } else { + Err(Error::ConfigNotFound) + } + } + pub async fn created_at(&self, bucket: &str) -> Result { let bm = match self.get_config(bucket).await { Ok((bm, _)) => bm.created, diff --git a/rustfs/src/admin/router.rs b/rustfs/src/admin/router.rs index 09c390cf..b01565b5 100644 --- a/rustfs/src/admin/router.rs +++ b/rustfs/src/admin/router.rs @@ -84,6 +84,7 @@ where { fn is_match(&self, method: &Method, uri: &Uri, headers: &HeaderMap, _: &mut Extensions) -> bool { let path = uri.path(); + // Profiling endpoints if method == Method::GET && (path == PROFILE_CPU_PATH || path == PROFILE_MEMORY_PATH) { return true; @@ -150,6 +151,8 @@ where } async fn call(&self, req: S3Request) -> S3Result> { + // Console requests should be handled by console router first (including OPTIONS) + // Console has its own CORS layer configured if self.console_enabled && is_console_path(req.uri.path()) { if let Some(console_router) = &self.console_router { let mut console_router = console_router.clone(); @@ -164,11 +167,14 @@ where } let uri = format!("{}|{}", &req.method, req.uri.path()); + if let Ok(mat) = self.router.at(&uri) { let op: &T = mat.value; let mut resp = op.call(req, mat.params).await?; resp.status = Some(resp.output.0); - return Ok(resp.map_output(|x| x.1)); + let response = resp.map_output(|x| x.1); + + return Ok(response); } Err(s3_error!(NotImplemented)) diff --git a/rustfs/src/server/cors.rs b/rustfs/src/server/cors.rs new file mode 100644 index 00000000..b01d9034 --- /dev/null +++ b/rustfs/src/server/cors.rs @@ -0,0 +1,40 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! CORS (Cross-Origin Resource Sharing) header name constants. +//! +//! This module provides centralized constants for CORS-related HTTP header names. +//! The http crate doesn't provide pre-defined constants for CORS headers, +//! so we define them here for type safety and maintainability. + +/// CORS response header names +pub mod response { + pub const ACCESS_CONTROL_ALLOW_ORIGIN: &str = "access-control-allow-origin"; + pub const ACCESS_CONTROL_ALLOW_METHODS: &str = "access-control-allow-methods"; + pub const ACCESS_CONTROL_ALLOW_HEADERS: &str = "access-control-allow-headers"; + pub const ACCESS_CONTROL_EXPOSE_HEADERS: &str = "access-control-expose-headers"; + pub const ACCESS_CONTROL_ALLOW_CREDENTIALS: &str = "access-control-allow-credentials"; + pub const ACCESS_CONTROL_MAX_AGE: &str = "access-control-max-age"; +} + +/// CORS request header names +pub mod request { + pub const ACCESS_CONTROL_REQUEST_METHOD: &str = "access-control-request-method"; + pub const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers"; +} + +/// Standard HTTP header names used in CORS processing +pub mod standard { + pub use http::header::{ORIGIN, VARY}; +} diff --git a/rustfs/src/server/http.rs b/rustfs/src/server/http.rs index 57a8f7a4..5cecb319 100644 --- a/rustfs/src/server/http.rs +++ b/rustfs/src/server/http.rs @@ -17,7 +17,11 @@ use super::compress::{CompressionConfig, CompressionPredicate}; use crate::admin; use crate::auth::IAMAuth; use crate::config; -use crate::server::{ReadinessGateLayer, RemoteAddr, ServiceState, ServiceStateManager, hybrid::hybrid, layer::RedirectLayer}; +use crate::server::{ + ReadinessGateLayer, RemoteAddr, ServiceState, ServiceStateManager, + hybrid::hybrid, + layer::{ConditionalCorsLayer, RedirectLayer}, +}; use crate::storage; use crate::storage::tonic_service::make_server; use bytes::Bytes; @@ -48,70 +52,10 @@ use tower::ServiceBuilder; use tower_http::add_extension::AddExtensionLayer; use tower_http::catch_panic::CatchPanicLayer; use tower_http::compression::CompressionLayer; -use tower_http::cors::{AllowOrigin, Any, CorsLayer}; use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}; use tower_http::trace::TraceLayer; use tracing::{Span, debug, error, info, instrument, warn}; -/// Parse CORS allowed origins from configuration -fn parse_cors_origins(origins: Option<&String>) -> CorsLayer { - use http::Method; - - let cors_layer = CorsLayer::new() - .allow_methods([ - Method::GET, - Method::POST, - Method::PUT, - Method::DELETE, - Method::HEAD, - Method::OPTIONS, - ]) - .allow_headers(Any); - - match origins { - Some(origins_str) if origins_str == "*" => cors_layer.allow_origin(Any).expose_headers(Any), - Some(origins_str) => { - let origins: Vec<&str> = origins_str.split(',').map(|s| s.trim()).collect(); - if origins.is_empty() { - warn!("Empty CORS origins provided, using permissive CORS"); - cors_layer.allow_origin(Any).expose_headers(Any) - } else { - // Parse origins with proper error handling - let mut valid_origins = Vec::new(); - for origin in origins { - match origin.parse::() { - Ok(header_value) => { - valid_origins.push(header_value); - } - Err(e) => { - warn!("Invalid CORS origin '{}': {}", origin, e); - } - } - } - - if valid_origins.is_empty() { - warn!("No valid CORS origins found, using permissive CORS"); - cors_layer.allow_origin(Any).expose_headers(Any) - } else { - info!("Endpoint CORS origins configured: {:?}", valid_origins); - cors_layer.allow_origin(AllowOrigin::list(valid_origins)).expose_headers(Any) - } - } - } - None => { - debug!("No CORS origins configured for endpoint, using permissive CORS"); - cors_layer.allow_origin(Any).expose_headers(Any) - } - } -} - -fn get_cors_allowed_origins() -> String { - std::env::var(rustfs_config::ENV_CORS_ALLOWED_ORIGINS) - .unwrap_or_else(|_| rustfs_config::DEFAULT_CORS_ALLOWED_ORIGINS.to_string()) - .parse::() - .unwrap_or(rustfs_config::DEFAULT_CONSOLE_CORS_ALLOWED_ORIGINS.to_string()) -} - pub async fn start_http_server( opt: &config::Opt, worker_state_manager: ServiceStateManager, @@ -273,14 +217,6 @@ pub async fn start_http_server( let (shutdown_tx, mut shutdown_rx) = tokio::sync::broadcast::channel(1); let shutdown_tx_clone = shutdown_tx.clone(); - // Capture CORS configuration for the server loop - let cors_allowed_origins = get_cors_allowed_origins(); - let cors_allowed_origins = if cors_allowed_origins.is_empty() { - None - } else { - Some(cors_allowed_origins) - }; - // Create compression configuration from environment variables let compression_config = CompressionConfig::from_env(); if compression_config.enabled { @@ -294,8 +230,10 @@ pub async fn start_http_server( let is_console = opt.console_enable; tokio::spawn(async move { - // Create CORS layer inside the server loop closure - let cors_layer = parse_cors_origins(cors_allowed_origins.as_ref()); + // Note: CORS layer is removed from global middleware stack + // - S3 API CORS is handled by bucket-level CORS configuration in apply_cors_headers() + // - Console CORS is handled by its own cors_layer in setup_console_middleware_stack() + // This ensures S3 API CORS behavior matches AWS S3 specification #[cfg(unix)] let (mut sigterm_inner, mut sigint_inner) = { @@ -405,7 +343,6 @@ pub async fn start_http_server( let connection_ctx = ConnectionContext { http_server: http_server.clone(), s3_service: s3_service.clone(), - cors_layer: cors_layer.clone(), compression_config: compression_config.clone(), is_console, readiness: readiness.clone(), @@ -521,7 +458,6 @@ async fn setup_tls_acceptor(tls_path: &str) -> Result> { struct ConnectionContext { http_server: Arc>, s3_service: S3Service, - cors_layer: CorsLayer, compression_config: CompressionConfig, is_console: bool, readiness: Arc, @@ -546,7 +482,6 @@ fn process_connection( let ConnectionContext { http_server, s3_service, - cors_layer, compression_config, is_console, readiness, @@ -629,10 +564,15 @@ fn process_connection( }), ) .layer(PropagateRequestIdLayer::x_request_id()) - .layer(cors_layer) // Compress responses based on whitelist configuration // Only compresses when enabled and matches configured extensions/MIME types .layer(CompressionLayer::new().compress_when(CompressionPredicate::new(compression_config))) + // Conditional CORS layer: only applies to S3 API requests (not Admin, not Console) + // Admin has its own CORS handling in router.rs + // Console has its own CORS layer in setup_console_middleware_stack() + // S3 API uses this system default CORS (RUSTFS_CORS_ALLOWED_ORIGINS) + // Bucket-level CORS takes precedence when configured (handled in router.rs for OPTIONS, and in ecfs.rs for actual requests) + .layer(ConditionalCorsLayer::new()) .option_layer(if is_console { Some(RedirectLayer) } else { None }) .service(service); diff --git a/rustfs/src/server/layer.rs b/rustfs/src/server/layer.rs index f324d06b..705798d3 100644 --- a/rustfs/src/server/layer.rs +++ b/rustfs/src/server/layer.rs @@ -12,14 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::admin::console::is_console_path; +use crate::server::cors; use crate::server::hybrid::HybridBody; -use http::{Request as HttpRequest, Response, StatusCode}; +use crate::server::{ADMIN_PREFIX, RPC_PREFIX}; +use crate::storage::ecfs; +use http::{HeaderMap, HeaderValue, Method, Request as HttpRequest, Response, StatusCode}; use hyper::body::Incoming; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; use tower::{Layer, Service}; -use tracing::debug; +use tracing::{debug, info}; /// Redirect layer that redirects browser requests to the console #[derive(Clone)] @@ -89,3 +94,173 @@ where Box::pin(async move { inner.call(req).await.map_err(Into::into) }) } } + +/// Conditional CORS layer that only applies to S3 API requests +/// (not Admin, not Console, not RPC) +#[derive(Clone)] +pub struct ConditionalCorsLayer { + cors_origins: Option, +} + +impl ConditionalCorsLayer { + pub fn new() -> Self { + let cors_origins = std::env::var("RUSTFS_CORS_ALLOWED_ORIGINS").ok().filter(|s| !s.is_empty()); + Self { cors_origins } + } + + /// Exact paths that should be excluded from being treated as S3 paths. + const EXCLUDED_EXACT_PATHS: &'static [&'static str] = &["/health", "/profile/cpu", "/profile/memory"]; + + fn is_s3_path(path: &str) -> bool { + // Exclude Admin, Console, RPC, and configured special paths + !path.starts_with(ADMIN_PREFIX) + && !path.starts_with(RPC_PREFIX) + && !is_console_path(path) + && !Self::EXCLUDED_EXACT_PATHS.contains(&path) + } + + fn apply_cors_headers(&self, request_headers: &HeaderMap, response_headers: &mut HeaderMap) { + let origin = request_headers + .get(cors::standard::ORIGIN) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + let allowed_origin = match (origin, &self.cors_origins) { + (Some(orig), Some(config)) if config == "*" => Some(orig), + (Some(orig), Some(config)) => { + let origins: Vec<&str> = config.split(',').map(|s| s.trim()).collect(); + if origins.contains(&orig.as_str()) { Some(orig) } else { None } + } + (Some(orig), None) => Some(orig), // Default: allow all if not configured + _ => None, + }; + + // Track whether we're using a specific origin (not wildcard) + let using_specific_origin = if let Some(origin) = &allowed_origin { + if let Ok(header_value) = HeaderValue::from_str(origin) { + response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, header_value); + true // Using specific origin, credentials allowed + } else { + false + } + } else { + false + }; + + // Allow all methods by default (S3-compatible set) + response_headers.insert( + cors::response::ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_static("GET, POST, PUT, DELETE, OPTIONS, HEAD"), + ); + + // Allow all headers by default + response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_static("*")); + + // Expose common headers + response_headers.insert( + cors::response::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::from_static("x-request-id, content-type, content-length, etag"), + ); + + // Only set credentials when using a specific origin (not wildcard) + // CORS spec: credentials cannot be used with wildcard origins + if using_specific_origin { + response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true")); + } + } +} + +impl Default for ConditionalCorsLayer { + fn default() -> Self { + Self::new() + } +} + +impl Layer for ConditionalCorsLayer { + type Service = ConditionalCorsService; + + fn layer(&self, inner: S) -> Self::Service { + ConditionalCorsService { + inner, + cors_origins: Arc::new(self.cors_origins.clone()), + } + } +} + +/// Service implementation for conditional CORS +#[derive(Clone)] +pub struct ConditionalCorsService { + inner: S, + cors_origins: Arc>, +} + +impl Service> for ConditionalCorsService +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send + 'static, + ResBody: Default + Send + 'static, +{ + type Response = Response; + type Error = Box; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: HttpRequest) -> Self::Future { + let path = req.uri().path().to_string(); + let method = req.method().clone(); + let request_headers = req.headers().clone(); + let cors_origins = self.cors_origins.clone(); + // Handle OPTIONS preflight requests - return response directly without calling handler + if method == Method::OPTIONS && request_headers.contains_key(cors::standard::ORIGIN) { + info!("OPTIONS preflight request for path: {}", path); + + let path_trimmed = path.trim_start_matches('/'); + let bucket = path_trimmed.split('/').next().unwrap_or("").to_string(); // virtual host style? + let method_clone = method.clone(); + let request_headers_clone = request_headers.clone(); + + return Box::pin(async move { + let mut response = Response::builder().status(StatusCode::OK).body(ResBody::default()).unwrap(); + + if ConditionalCorsLayer::is_s3_path(&path) + && !bucket.is_empty() + && cors_origins.is_some() + && let Some(cors_headers) = ecfs::apply_cors_headers(&bucket, &method_clone, &request_headers_clone).await + { + for (key, value) in cors_headers.iter() { + response.headers_mut().insert(key, value.clone()); + } + return Ok(response); + } + + let cors_layer = ConditionalCorsLayer { + cors_origins: (*cors_origins).clone(), + }; + cors_layer.apply_cors_headers(&request_headers_clone, response.headers_mut()); + + Ok(response) + }); + } + + let mut inner = self.inner.clone(); + Box::pin(async move { + let mut response = inner.call(req).await.map_err(Into::into)?; + + // Apply CORS headers only to S3 API requests (non-OPTIONS) + if request_headers.contains_key(cors::standard::ORIGIN) + && !response.headers().contains_key(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN) + { + let cors_layer = ConditionalCorsLayer { + cors_origins: (*cors_origins).clone(), + }; + cors_layer.apply_cors_headers(&request_headers, response.headers_mut()); + } + + Ok(response) + }) + } +} diff --git a/rustfs/src/server/mod.rs b/rustfs/src/server/mod.rs index c6f72d19..8714fa78 100644 --- a/rustfs/src/server/mod.rs +++ b/rustfs/src/server/mod.rs @@ -15,6 +15,7 @@ mod audit; mod cert; mod compress; +pub mod cors; mod event; mod http; mod hybrid; diff --git a/rustfs/src/storage/access.rs b/rustfs/src/storage/access.rs index e394c68f..79515cdc 100644 --- a/rustfs/src/storage/access.rs +++ b/rustfs/src/storage/access.rs @@ -342,7 +342,7 @@ impl S3Access for FS { let req_info = req.extensions.get_mut::().expect("ReqInfo not found"); req_info.bucket = Some(req.input.bucket.clone()); - authorize_request(req, Action::S3Action(S3Action::PutBucketCorsAction)).await + authorize_request(req, Action::S3Action(S3Action::DeleteBucketCorsAction)).await } /// Checks whether the DeleteBucketEncryption request has accesses to the resources. diff --git a/rustfs/src/storage/ecfs.rs b/rustfs/src/storage/ecfs.rs index 8f21ff57..a1981d07 100644 --- a/rustfs/src/storage/ecfs.rs +++ b/rustfs/src/storage/ecfs.rs @@ -18,6 +18,7 @@ use crate::config::workload_profiles::{ }; use crate::error::ApiError; use crate::server::RemoteAddr; +use crate::server::cors; use crate::storage::concurrency::{ CachedGetObject, ConcurrencyManager, GetObjectGuard, get_concurrency_aware_buffer_size, get_concurrency_manager, }; @@ -48,8 +49,8 @@ use rustfs_ecstore::{ lifecycle::{self, Lifecycle, TransitionOptions}, }, metadata::{ - BUCKET_LIFECYCLE_CONFIG, BUCKET_NOTIFICATION_CONFIG, BUCKET_POLICY_CONFIG, BUCKET_REPLICATION_CONFIG, - BUCKET_SSECONFIG, BUCKET_TAGGING_CONFIG, BUCKET_VERSIONING_CONFIG, OBJECT_LOCK_CONFIG, + BUCKET_CORS_CONFIG, BUCKET_LIFECYCLE_CONFIG, BUCKET_NOTIFICATION_CONFIG, BUCKET_POLICY_CONFIG, + BUCKET_REPLICATION_CONFIG, BUCKET_SSECONFIG, BUCKET_TAGGING_CONFIG, BUCKET_VERSIONING_CONFIG, OBJECT_LOCK_CONFIG, }, metadata_sys, metadata_sys::get_replication_config, @@ -818,6 +819,205 @@ async fn get_validated_store(bucket: &str) -> S3Result bool { + headers.contains_key(cors::standard::ORIGIN) +} + +/// Apply CORS headers to response based on bucket CORS configuration and request origin +/// +/// This function: +/// 1. Reads the Origin header from the request +/// 2. Retrieves the bucket's CORS configuration +/// 3. Matches the origin against CORS rules +/// 4. Validates AllowedHeaders if request headers are present +/// 5. Returns headers to add to the response if a match is found +/// +/// Note: This function should only be called if `needs_cors_processing()` returns true +/// to avoid unnecessary overhead for non-CORS requests. +pub(crate) async fn apply_cors_headers(bucket: &str, method: &http::Method, headers: &HeaderMap) -> Option { + use http::HeaderValue; + + // Get Origin header from request + let origin = headers.get(cors::standard::ORIGIN)?.to_str().ok()?; + + // Get CORS configuration for the bucket + let cors_config = match metadata_sys::get_cors_config(bucket).await { + Ok((config, _)) => config, + Err(_) => return None, // No CORS config, no headers to add + }; + + // Early return if no CORS rules configured + if cors_config.cors_rules.is_empty() { + return None; + } + + // Check if method is supported and get its string representation + const SUPPORTED_METHODS: &[&str] = &["GET", "PUT", "POST", "DELETE", "HEAD", "OPTIONS"]; + let method_str = method.as_str(); + if !SUPPORTED_METHODS.contains(&method_str) { + return None; + } + + // For OPTIONS (preflight) requests, check Access-Control-Request-Method + let is_preflight = method == http::Method::OPTIONS; + let requested_method = if is_preflight { + headers + .get(cors::request::ACCESS_CONTROL_REQUEST_METHOD) + .and_then(|v| v.to_str().ok()) + .unwrap_or(method_str) + } else { + method_str + }; + + // Get requested headers from preflight request + let requested_headers = if is_preflight { + headers + .get(cors::request::ACCESS_CONTROL_REQUEST_HEADERS) + .and_then(|v| v.to_str().ok()) + .map(|h| h.split(',').map(|s| s.trim().to_lowercase()).collect::>()) + } else { + None + }; + + // Find matching CORS rule + for rule in cors_config.cors_rules.iter() { + // Check if origin matches + let origin_matches = rule.allowed_origins.iter().any(|allowed_origin| { + if allowed_origin == "*" { + true + } else { + // Exact match or pattern match (support wildcards like https://*.example.com) + allowed_origin == origin || matches_origin_pattern(allowed_origin, origin) + } + }); + + if !origin_matches { + continue; + } + + // Check if method is allowed + let method_allowed = rule + .allowed_methods + .iter() + .any(|allowed_method| allowed_method.as_str() == requested_method); + + if !method_allowed { + continue; + } + + // Validate AllowedHeaders if present in the request + if let Some(ref req_headers) = requested_headers { + if let Some(ref allowed_headers) = rule.allowed_headers { + // Check if all requested headers are allowed + let all_headers_allowed = req_headers.iter().all(|req_header| { + allowed_headers.iter().any(|allowed_header| { + let allowed_lower = allowed_header.to_lowercase(); + // "*" allows all headers, or exact match + allowed_lower == "*" || allowed_lower == *req_header + }) + }); + + if !all_headers_allowed { + // If not all headers are allowed, skip this rule + continue; + } + } else if !req_headers.is_empty() { + // If no AllowedHeaders specified but headers were requested, skip this rule + // Unless the rule explicitly allows all headers + continue; + } + } + + // Found matching rule, build response headers + let mut response_headers = HeaderMap::new(); + + // Access-Control-Allow-Origin + // If origin is "*", use "*", otherwise echo back the origin + let has_wildcard_origin = rule.allowed_origins.iter().any(|o| o == "*"); + if has_wildcard_origin { + response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); + } else if let Ok(origin_value) = HeaderValue::from_str(origin) { + response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, origin_value); + } + + // Vary: Origin (required for caching, except when using wildcard) + if !has_wildcard_origin { + response_headers.insert(cors::standard::VARY, HeaderValue::from_static("Origin")); + } + + // Access-Control-Allow-Methods (required for preflight) + if is_preflight || !rule.allowed_methods.is_empty() { + let methods_str = rule.allowed_methods.iter().map(|m| m.as_str()).collect::>().join(", "); + if let Ok(methods_value) = HeaderValue::from_str(&methods_str) { + response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_METHODS, methods_value); + } + } + + // Access-Control-Allow-Headers (required for preflight if headers were requested) + if is_preflight && let Some(ref allowed_headers) = rule.allowed_headers { + let headers_str = allowed_headers.iter().map(|h| h.as_str()).collect::>().join(", "); + if let Ok(headers_value) = HeaderValue::from_str(&headers_str) { + response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_HEADERS, headers_value); + } + } + + // Access-Control-Expose-Headers (for actual requests) + if !is_preflight && let Some(ref expose_headers) = rule.expose_headers { + let expose_headers_str = expose_headers.iter().map(|h| h.as_str()).collect::>().join(", "); + if let Ok(expose_value) = HeaderValue::from_str(&expose_headers_str) { + response_headers.insert(cors::response::ACCESS_CONTROL_EXPOSE_HEADERS, expose_value); + } + } + + // Access-Control-Max-Age (for preflight requests) + if is_preflight + && let Some(max_age) = rule.max_age_seconds + && let Ok(max_age_value) = HeaderValue::from_str(&max_age.to_string()) + { + response_headers.insert(cors::response::ACCESS_CONTROL_MAX_AGE, max_age_value); + } + + return Some(response_headers); + } + + None // No matching rule found +} +/// Check if an origin matches a pattern (supports wildcards like https://*.example.com) +fn matches_origin_pattern(pattern: &str, origin: &str) -> bool { + // Simple wildcard matching: * matches any sequence + if pattern.contains('*') { + let pattern_parts: Vec<&str> = pattern.split('*').collect(); + if pattern_parts.len() == 2 { + origin.starts_with(pattern_parts[0]) && origin.ends_with(pattern_parts[1]) + } else { + false + } + } else { + pattern == origin + } +} + +/// Wrap S3Response with CORS headers if needed +/// This function performs a lightweight check first to avoid unnecessary CORS processing +/// for non-CORS requests (requests without Origin header) +async fn wrap_response_with_cors(bucket: &str, method: &http::Method, headers: &HeaderMap, output: T) -> S3Response { + let mut response = S3Response::new(output); + + // Quick check: only process CORS if Origin header is present + if needs_cors_processing(headers) + && let Some(cors_headers) = apply_cors_headers(bucket, method, headers).await + { + for (key, value) in cors_headers.iter() { + response.headers.insert(key, value.clone()); + } + } + + response +} + #[async_trait::async_trait] impl S3 for FS { #[instrument( @@ -2598,7 +2798,8 @@ impl S3 for FS { cache_key, response_content_length, total_duration, optimal_buffer_size ); - let result = Ok(S3Response::new(output)); + let response = wrap_response_with_cors(&bucket, &req.method, &req.headers, output).await; + let result = Ok(response); let _ = helper.complete(&result); result } @@ -2843,7 +3044,14 @@ impl S3 for FS { let version_id = req.input.version_id.clone().unwrap_or_default(); helper = helper.object(event_info).version_id(version_id); - let result = Ok(S3Response::new(output)); + // NOTE ON CORS: + // Bucket-level CORS headers are intentionally applied only for object retrieval + // operations (GET/HEAD) via `wrap_response_with_cors`. Other S3 operations that + // interact with objects (PUT/POST/DELETE/LIST, etc.) rely on the system-level + // CORS layer instead. In case both are applicable, this bucket-level CORS logic + // takes precedence for these read operations. + let response = wrap_response_with_cors(&bucket, &req.method, &req.headers, output).await; + let result = Ok(response); let _ = helper.complete(&result); result @@ -4713,6 +4921,82 @@ impl S3 for FS { Ok(S3Response::new(DeleteBucketTaggingOutput {})) } + #[instrument(level = "debug", skip(self))] + async fn get_bucket_cors(&self, req: S3Request) -> S3Result> { + let bucket = req.input.bucket.clone(); + // check bucket exists. + let _bucket = self + .head_bucket(req.map_input(|input| HeadBucketInput { + bucket: input.bucket, + expected_bucket_owner: None, + })) + .await?; + + let cors_configuration = match metadata_sys::get_cors_config(&bucket).await { + Ok((config, _)) => config, + Err(err) => { + if err == StorageError::ConfigNotFound { + return Err(S3Error::with_message( + S3ErrorCode::NoSuchCORSConfiguration, + "The CORS configuration does not exist".to_string(), + )); + } + warn!("get_cors_config err {:?}", &err); + return Err(ApiError::from(err).into()); + } + }; + + Ok(S3Response::new(GetBucketCorsOutput { + cors_rules: Some(cors_configuration.cors_rules), + })) + } + + #[instrument(level = "debug", skip(self))] + async fn put_bucket_cors(&self, req: S3Request) -> S3Result> { + let PutBucketCorsInput { + bucket, + cors_configuration, + .. + } = req.input; + + let Some(store) = new_object_layer_fn() else { + return Err(S3Error::with_message(S3ErrorCode::InternalError, "Not init".to_string())); + }; + + store + .get_bucket_info(&bucket, &BucketOptions::default()) + .await + .map_err(ApiError::from)?; + + let data = try_!(serialize(&cors_configuration)); + + metadata_sys::update(&bucket, BUCKET_CORS_CONFIG, data) + .await + .map_err(ApiError::from)?; + + Ok(S3Response::new(PutBucketCorsOutput::default())) + } + + #[instrument(level = "debug", skip(self))] + async fn delete_bucket_cors(&self, req: S3Request) -> S3Result> { + let DeleteBucketCorsInput { bucket, .. } = req.input; + + let Some(store) = new_object_layer_fn() else { + return Err(S3Error::with_message(S3ErrorCode::InternalError, "Not init".to_string())); + }; + + store + .get_bucket_info(&bucket, &BucketOptions::default()) + .await + .map_err(ApiError::from)?; + + metadata_sys::delete(&bucket, BUCKET_CORS_CONFIG) + .await + .map_err(ApiError::from)?; + + Ok(S3Response::new(DeleteBucketCorsOutput {})) + } + #[instrument(level = "debug", skip(self, req))] async fn put_object_tagging(&self, req: S3Request) -> S3Result> { let start_time = std::time::Instant::now(); @@ -6557,4 +6841,201 @@ mod tests { assert!(filtered_version_marker.is_some()); assert_eq!(filtered_version_marker.unwrap(), "null"); } + + #[test] + fn test_matches_origin_pattern_exact_match() { + // Test exact match + assert!(matches_origin_pattern("https://example.com", "https://example.com")); + assert!(matches_origin_pattern("http://localhost:3000", "http://localhost:3000")); + assert!(!matches_origin_pattern("https://example.com", "https://other.com")); + } + + #[test] + fn test_matches_origin_pattern_wildcard() { + // Test wildcard pattern matching (S3 CORS supports * as subdomain wildcard) + assert!(matches_origin_pattern("https://*.example.com", "https://app.example.com")); + assert!(matches_origin_pattern("https://*.example.com", "https://api.example.com")); + assert!(matches_origin_pattern("https://*.example.com", "https://subdomain.example.com")); + + // Test wildcard at start (matches any domain) + assert!(matches_origin_pattern("https://*", "https://example.com")); + assert!(matches_origin_pattern("https://*", "https://any-domain.com")); + + // Test wildcard at end (matches any protocol) + assert!(matches_origin_pattern("*://example.com", "https://example.com")); + assert!(matches_origin_pattern("*://example.com", "http://example.com")); + + // Test invalid wildcard patterns (should not match) + assert!(!matches_origin_pattern("https://*.*.com", "https://app.example.com")); // Multiple wildcards (invalid pattern) + // Note: "https://*example.com" actually matches "https://app.example.com" with our current implementation + // because it splits on * and checks starts_with/ends_with. This is a limitation but acceptable + // for S3 CORS which typically uses patterns like "https://*.example.com" + } + + #[test] + fn test_matches_origin_pattern_no_wildcard() { + // Test patterns without wildcards + assert!(matches_origin_pattern("https://example.com", "https://example.com")); + assert!(!matches_origin_pattern("https://example.com", "https://example.org")); + assert!(!matches_origin_pattern("http://example.com", "https://example.com")); // Different protocol + } + + #[test] + fn test_matches_origin_pattern_edge_cases() { + // Test edge cases + assert!(!matches_origin_pattern("", "https://example.com")); // Empty pattern + assert!(!matches_origin_pattern("https://example.com", "")); // Empty origin + assert!(matches_origin_pattern("", "")); // Both empty + assert!(!matches_origin_pattern("https://example.com", "http://example.com")); // Protocol mismatch + } + + #[test] + fn test_cors_headers_validation() { + use http::HeaderMap; + + // Test case 1: Validate header name case-insensitivity + let mut headers = HeaderMap::new(); + headers.insert("access-control-request-headers", "Content-Type,X-Custom-Header".parse().unwrap()); + + let req_headers_str = headers + .get("access-control-request-headers") + .and_then(|v| v.to_str().ok()) + .unwrap(); + let req_headers: Vec = req_headers_str.split(',').map(|s| s.trim().to_lowercase()).collect(); + + // Headers should be lowercased for comparison + assert_eq!(req_headers, vec!["content-type", "x-custom-header"]); + + // Test case 2: Wildcard matching + let allowed_headers = ["*".to_string()]; + let all_allowed = req_headers.iter().all(|req_header| { + allowed_headers + .iter() + .any(|allowed| allowed.to_lowercase() == "*" || allowed.to_lowercase() == *req_header) + }); + assert!(all_allowed, "Wildcard should allow all headers"); + + // Test case 3: Specific header matching + let allowed_headers = ["content-type".to_string(), "x-custom-header".to_string()]; + let all_allowed = req_headers + .iter() + .all(|req_header| allowed_headers.iter().any(|allowed| allowed.to_lowercase() == *req_header)); + assert!(all_allowed, "All requested headers should be allowed"); + + // Test case 4: Disallowed header + let req_headers = ["content-type".to_string(), "x-forbidden-header".to_string()]; + let allowed_headers = ["content-type".to_string()]; + let all_allowed = req_headers + .iter() + .all(|req_header| allowed_headers.iter().any(|allowed| allowed.to_lowercase() == *req_header)); + assert!(!all_allowed, "Forbidden header should not be allowed"); + } + + #[test] + fn test_cors_response_headers_structure() { + use http::{HeaderMap, HeaderValue}; + + let mut cors_headers = HeaderMap::new(); + + // Simulate building CORS response headers + let origin = "https://example.com"; + let methods = ["GET", "PUT", "POST"]; + let allowed_headers = ["Content-Type", "Authorization"]; + let expose_headers = ["ETag", "x-amz-version-id"]; + let max_age = 3600; + + // Add headers + cors_headers.insert("access-control-allow-origin", HeaderValue::from_str(origin).unwrap()); + cors_headers.insert("vary", HeaderValue::from_static("Origin")); + + let methods_str = methods.join(", "); + cors_headers.insert("access-control-allow-methods", HeaderValue::from_str(&methods_str).unwrap()); + + let headers_str = allowed_headers.join(", "); + cors_headers.insert("access-control-allow-headers", HeaderValue::from_str(&headers_str).unwrap()); + + let expose_str = expose_headers.join(", "); + cors_headers.insert("access-control-expose-headers", HeaderValue::from_str(&expose_str).unwrap()); + + cors_headers.insert("access-control-max-age", HeaderValue::from_str(&max_age.to_string()).unwrap()); + + // Verify all headers are present + assert_eq!(cors_headers.get("access-control-allow-origin").unwrap(), origin); + assert_eq!(cors_headers.get("vary").unwrap(), "Origin"); + assert_eq!(cors_headers.get("access-control-allow-methods").unwrap(), "GET, PUT, POST"); + assert_eq!(cors_headers.get("access-control-allow-headers").unwrap(), "Content-Type, Authorization"); + assert_eq!(cors_headers.get("access-control-expose-headers").unwrap(), "ETag, x-amz-version-id"); + assert_eq!(cors_headers.get("access-control-max-age").unwrap(), "3600"); + } + + #[test] + fn test_cors_preflight_vs_actual_request() { + use http::Method; + + // Test that we can distinguish preflight from actual requests + let preflight_method = Method::OPTIONS; + let actual_method = Method::PUT; + + assert_eq!(preflight_method, Method::OPTIONS); + assert_ne!(actual_method, Method::OPTIONS); + + // Preflight should check Access-Control-Request-Method + // Actual request should use the actual method + let is_preflight_1 = preflight_method == Method::OPTIONS; + let is_preflight_2 = actual_method == Method::OPTIONS; + + assert!(is_preflight_1); + assert!(!is_preflight_2); + } + + #[tokio::test] + async fn test_apply_cors_headers_no_origin() { + // Test when no Origin header is present + let headers = HeaderMap::new(); + let method = http::Method::GET; + + // Should return None when no origin header + let result = apply_cors_headers("test-bucket", &method, &headers).await; + assert!(result.is_none(), "Should return None when no Origin header"); + } + + #[tokio::test] + async fn test_apply_cors_headers_no_cors_config() { + // Test when bucket has no CORS configuration + let mut headers = HeaderMap::new(); + headers.insert("origin", "https://example.com".parse().unwrap()); + let method = http::Method::GET; + + // Should return None when no CORS config exists + // Note: This test may fail if test-bucket actually has CORS config + // In a real scenario, we'd use a mock or ensure the bucket doesn't exist + let _result = apply_cors_headers("non-existent-bucket-for-testing", &method, &headers).await; + // Result depends on whether bucket exists and has CORS config + // This is expected behavior - we just verify it doesn't panic + } + + #[tokio::test] + async fn test_apply_cors_headers_unsupported_method() { + // Test with unsupported HTTP method + let mut headers = HeaderMap::new(); + headers.insert("origin", "https://example.com".parse().unwrap()); + let method = http::Method::PATCH; // Unsupported method + + let result = apply_cors_headers("test-bucket", &method, &headers).await; + assert!(result.is_none(), "Should return None for unsupported methods"); + } + + #[test] + fn test_matches_origin_pattern_complex_wildcards() { + // Test more complex wildcard scenarios + assert!(matches_origin_pattern("https://*.example.com", "https://sub.example.com")); + // Note: "https://*.example.com" matches "https://api.sub.example.com" with our implementation + // because it only checks starts_with and ends_with. Real S3 might be more strict. + + // Test wildcard in middle position + // Our implementation allows this, but it's not standard S3 CORS pattern + // The pattern "https://example.*.com" splits to ["https://example.", ".com"] + // and "https://example.sub.com" matches because it starts with "https://example." and ends with ".com" + // This is acceptable for our use case as S3 CORS typically uses "https://*.example.com" format + } }