Merge branch 'main' of github.com:rustfs/rustfs into fix/axum-trusted-proxies

This commit is contained in:
houseme
2026-01-16 11:46:30 +08:00
10 changed files with 805 additions and 100 deletions

View File

@@ -25,8 +25,8 @@ use byteorder::{BigEndian, ByteOrder, LittleEndian};
use rmp_serde::Serializer as rmpSerializer;
use rustfs_policy::policy::BucketPolicy;
use s3s::dto::{
BucketLifecycleConfiguration, NotificationConfiguration, ObjectLockConfiguration, ReplicationConfiguration,
ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration,
BucketLifecycleConfiguration, CORSConfiguration, NotificationConfiguration, ObjectLockConfiguration,
ReplicationConfiguration, ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration,
};
use serde::Serializer;
use serde::{Deserialize, Serialize};
@@ -49,6 +49,7 @@ pub const OBJECT_LOCK_CONFIG: &str = "object-lock.xml";
pub const BUCKET_VERSIONING_CONFIG: &str = "versioning.xml";
pub const BUCKET_REPLICATION_CONFIG: &str = "replication.xml";
pub const BUCKET_TARGETS_FILE: &str = "bucket-targets.json";
pub const BUCKET_CORS_CONFIG: &str = "cors.xml";
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(rename_all = "PascalCase", default)]
@@ -67,6 +68,7 @@ pub struct BucketMetadata {
pub replication_config_xml: Vec<u8>,
pub bucket_targets_config_json: Vec<u8>,
pub bucket_targets_config_meta_json: Vec<u8>,
pub cors_config_xml: Vec<u8>,
pub policy_config_updated_at: OffsetDateTime,
pub object_lock_config_updated_at: OffsetDateTime,
@@ -79,6 +81,7 @@ pub struct BucketMetadata {
pub notification_config_updated_at: OffsetDateTime,
pub bucket_targets_config_updated_at: OffsetDateTime,
pub bucket_targets_config_meta_updated_at: OffsetDateTime,
pub cors_config_updated_at: OffsetDateTime,
#[serde(skip)]
pub new_field_updated_at: OffsetDateTime,
@@ -105,6 +108,8 @@ pub struct BucketMetadata {
pub bucket_target_config: Option<BucketTargets>,
#[serde(skip)]
pub bucket_target_config_meta: Option<HashMap<String, String>>,
#[serde(skip)]
pub cors_config: Option<CORSConfiguration>,
}
impl Default for BucketMetadata {
@@ -124,6 +129,7 @@ impl Default for BucketMetadata {
replication_config_xml: Default::default(),
bucket_targets_config_json: Default::default(),
bucket_targets_config_meta_json: Default::default(),
cors_config_xml: Default::default(),
policy_config_updated_at: OffsetDateTime::UNIX_EPOCH,
object_lock_config_updated_at: OffsetDateTime::UNIX_EPOCH,
encryption_config_updated_at: OffsetDateTime::UNIX_EPOCH,
@@ -135,6 +141,7 @@ impl Default for BucketMetadata {
notification_config_updated_at: OffsetDateTime::UNIX_EPOCH,
bucket_targets_config_updated_at: OffsetDateTime::UNIX_EPOCH,
bucket_targets_config_meta_updated_at: OffsetDateTime::UNIX_EPOCH,
cors_config_updated_at: OffsetDateTime::UNIX_EPOCH,
new_field_updated_at: OffsetDateTime::UNIX_EPOCH,
policy_config: Default::default(),
notification_config: Default::default(),
@@ -147,6 +154,7 @@ impl Default for BucketMetadata {
replication_config: Default::default(),
bucket_target_config: Default::default(),
bucket_target_config_meta: Default::default(),
cors_config: Default::default(),
}
}
}
@@ -295,6 +303,10 @@ impl BucketMetadata {
self.bucket_targets_config_json = data.clone();
self.bucket_targets_config_updated_at = updated;
}
BUCKET_CORS_CONFIG => {
self.cors_config_xml = data;
self.cors_config_updated_at = updated;
}
_ => return Err(Error::other(format!("config file not found : {config_file}"))),
}
@@ -365,6 +377,9 @@ impl BucketMetadata {
} else {
self.bucket_target_config = Some(BucketTargets::default())
}
if !self.cors_config_xml.is_empty() {
self.cors_config = Some(deserialize::<CORSConfiguration>(&self.cors_config_xml)?);
}
Ok(())
}

View File

@@ -28,8 +28,8 @@ use rustfs_common::heal_channel::HealOpts;
use rustfs_policy::policy::BucketPolicy;
use s3s::dto::ReplicationConfiguration;
use s3s::dto::{
BucketLifecycleConfiguration, NotificationConfiguration, ObjectLockConfiguration, ServerSideEncryptionConfiguration, Tagging,
VersioningConfiguration,
BucketLifecycleConfiguration, CORSConfiguration, NotificationConfiguration, ObjectLockConfiguration,
ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration,
};
use std::collections::HashSet;
use std::sync::OnceLock;
@@ -110,6 +110,13 @@ pub async fn get_bucket_targets_config(bucket: &str) -> Result<BucketTargets> {
bucket_meta_sys.get_bucket_targets_config(bucket).await
}
pub async fn get_cors_config(bucket: &str) -> Result<(CORSConfiguration, OffsetDateTime)> {
let bucket_meta_sys_lock = get_bucket_metadata_sys()?;
let bucket_meta_sys = bucket_meta_sys_lock.read().await;
bucket_meta_sys.get_cors_config(bucket).await
}
pub async fn get_tagging_config(bucket: &str) -> Result<(Tagging, OffsetDateTime)> {
let bucket_meta_sys_lock = get_bucket_metadata_sys()?;
let bucket_meta_sys = bucket_meta_sys_lock.read().await;
@@ -500,6 +507,16 @@ impl BucketMetadataSys {
}
}
pub async fn get_cors_config(&self, bucket: &str) -> Result<(CORSConfiguration, OffsetDateTime)> {
let (bm, _) = self.get_config(bucket).await?;
if let Some(config) = &bm.cors_config {
Ok((config.clone(), bm.cors_config_updated_at))
} else {
Err(Error::ConfigNotFound)
}
}
pub async fn created_at(&self, bucket: &str) -> Result<OffsetDateTime> {
let bm = match self.get_config(bucket).await {
Ok((bm, _)) => bm.created,

View File

@@ -844,7 +844,11 @@ impl LocalDisk {
self.write_all_internal(&tmp_file_path, InternalBuf::Ref(buf), sync, &tmp_volume_dir)
.await?;
rename_all(tmp_file_path, file_path, volume_dir).await
rename_all(tmp_file_path, &file_path, volume_dir).await?;
// Invalidate cache after successful write
get_global_file_cache().invalidate(&file_path).await;
Ok(())
}
// write_all_public for trail

View File

@@ -84,6 +84,7 @@ where
{
fn is_match(&self, method: &Method, uri: &Uri, headers: &HeaderMap, _: &mut Extensions) -> bool {
let path = uri.path();
// Profiling endpoints
if method == Method::GET && (path == PROFILE_CPU_PATH || path == PROFILE_MEMORY_PATH) {
return true;
@@ -150,6 +151,8 @@ where
}
async fn call(&self, req: S3Request<Body>) -> S3Result<S3Response<Body>> {
// Console requests should be handled by console router first (including OPTIONS)
// Console has its own CORS layer configured
if self.console_enabled && is_console_path(req.uri.path()) {
if let Some(console_router) = &self.console_router {
let mut console_router = console_router.clone();
@@ -164,11 +167,14 @@ where
}
let uri = format!("{}|{}", &req.method, req.uri.path());
if let Ok(mat) = self.router.at(&uri) {
let op: &T = mat.value;
let mut resp = op.call(req, mat.params).await?;
resp.status = Some(resp.output.0);
return Ok(resp.map_output(|x| x.1));
let response = resp.map_output(|x| x.1);
return Ok(response);
}
Err(s3_error!(NotImplemented))

40
rustfs/src/server/cors.rs Normal file
View File

@@ -0,0 +1,40 @@
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! CORS (Cross-Origin Resource Sharing) header name constants.
//!
//! This module provides centralized constants for CORS-related HTTP header names.
//! The http crate doesn't provide pre-defined constants for CORS headers,
//! so we define them here for type safety and maintainability.
/// CORS response header names
pub mod response {
pub const ACCESS_CONTROL_ALLOW_ORIGIN: &str = "access-control-allow-origin";
pub const ACCESS_CONTROL_ALLOW_METHODS: &str = "access-control-allow-methods";
pub const ACCESS_CONTROL_ALLOW_HEADERS: &str = "access-control-allow-headers";
pub const ACCESS_CONTROL_EXPOSE_HEADERS: &str = "access-control-expose-headers";
pub const ACCESS_CONTROL_ALLOW_CREDENTIALS: &str = "access-control-allow-credentials";
pub const ACCESS_CONTROL_MAX_AGE: &str = "access-control-max-age";
}
/// CORS request header names
pub mod request {
pub const ACCESS_CONTROL_REQUEST_METHOD: &str = "access-control-request-method";
pub const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers";
}
/// Standard HTTP header names used in CORS processing
pub mod standard {
pub use http::header::{ORIGIN, VARY};
}

View File

@@ -17,7 +17,11 @@ use super::compress::{CompressionConfig, CompressionPredicate};
use crate::admin;
use crate::auth::IAMAuth;
use crate::config;
use crate::server::{ReadinessGateLayer, RemoteAddr, ServiceState, ServiceStateManager, hybrid::hybrid, layer::RedirectLayer};
use crate::server::{
ReadinessGateLayer, RemoteAddr, ServiceState, ServiceStateManager,
hybrid::hybrid,
layer::{ConditionalCorsLayer, RedirectLayer},
};
use crate::storage;
use crate::storage::tonic_service::make_server;
use bytes::Bytes;
@@ -49,70 +53,10 @@ use tower::ServiceBuilder;
use tower_http::add_extension::AddExtensionLayer;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::trace::TraceLayer;
use tracing::{Span, debug, error, info, instrument, warn};
/// Parse CORS allowed origins from configuration
fn parse_cors_origins(origins: Option<&String>) -> CorsLayer {
use http::Method;
let cors_layer = CorsLayer::new()
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::HEAD,
Method::OPTIONS,
])
.allow_headers(Any);
match origins {
Some(origins_str) if origins_str == "*" => cors_layer.allow_origin(Any).expose_headers(Any),
Some(origins_str) => {
let origins: Vec<&str> = origins_str.split(',').map(|s| s.trim()).collect();
if origins.is_empty() {
warn!("Empty CORS origins provided, using permissive CORS");
cors_layer.allow_origin(Any).expose_headers(Any)
} else {
// Parse origins with proper error handling
let mut valid_origins = Vec::new();
for origin in origins {
match origin.parse::<http::HeaderValue>() {
Ok(header_value) => {
valid_origins.push(header_value);
}
Err(e) => {
warn!("Invalid CORS origin '{}': {}", origin, e);
}
}
}
if valid_origins.is_empty() {
warn!("No valid CORS origins found, using permissive CORS");
cors_layer.allow_origin(Any).expose_headers(Any)
} else {
info!("Endpoint CORS origins configured: {:?}", valid_origins);
cors_layer.allow_origin(AllowOrigin::list(valid_origins)).expose_headers(Any)
}
}
}
None => {
debug!("No CORS origins configured for endpoint, using permissive CORS");
cors_layer.allow_origin(Any).expose_headers(Any)
}
}
}
fn get_cors_allowed_origins() -> String {
std::env::var(rustfs_config::ENV_CORS_ALLOWED_ORIGINS)
.unwrap_or_else(|_| rustfs_config::DEFAULT_CORS_ALLOWED_ORIGINS.to_string())
.parse::<String>()
.unwrap_or(rustfs_config::DEFAULT_CONSOLE_CORS_ALLOWED_ORIGINS.to_string())
}
pub async fn start_http_server(
opt: &config::Opt,
worker_state_manager: ServiceStateManager,
@@ -274,14 +218,6 @@ pub async fn start_http_server(
let (shutdown_tx, mut shutdown_rx) = tokio::sync::broadcast::channel(1);
let shutdown_tx_clone = shutdown_tx.clone();
// Capture CORS configuration for the server loop
let cors_allowed_origins = get_cors_allowed_origins();
let cors_allowed_origins = if cors_allowed_origins.is_empty() {
None
} else {
Some(cors_allowed_origins)
};
// Create compression configuration from environment variables
let compression_config = CompressionConfig::from_env();
if compression_config.enabled {
@@ -295,8 +231,10 @@ pub async fn start_http_server(
let is_console = opt.console_enable;
tokio::spawn(async move {
// Create CORS layer inside the server loop closure
let cors_layer = parse_cors_origins(cors_allowed_origins.as_ref());
// Note: CORS layer is removed from global middleware stack
// - S3 API CORS is handled by bucket-level CORS configuration in apply_cors_headers()
// - Console CORS is handled by its own cors_layer in setup_console_middleware_stack()
// This ensures S3 API CORS behavior matches AWS S3 specification
#[cfg(unix)]
let (mut sigterm_inner, mut sigint_inner) = {
@@ -406,7 +344,6 @@ pub async fn start_http_server(
let connection_ctx = ConnectionContext {
http_server: http_server.clone(),
s3_service: s3_service.clone(),
cors_layer: cors_layer.clone(),
compression_config: compression_config.clone(),
is_console,
readiness: readiness.clone(),
@@ -522,7 +459,6 @@ async fn setup_tls_acceptor(tls_path: &str) -> Result<Option<TlsAcceptor>> {
struct ConnectionContext {
http_server: Arc<ConnBuilder<TokioExecutor>>,
s3_service: S3Service,
cors_layer: CorsLayer,
compression_config: CompressionConfig,
is_console: bool,
readiness: Arc<GlobalReadiness>,
@@ -547,7 +483,6 @@ fn process_connection(
let ConnectionContext {
http_server,
s3_service,
cors_layer,
compression_config,
is_console,
readiness,
@@ -636,10 +571,15 @@ fn process_connection(
}),
)
.layer(PropagateRequestIdLayer::x_request_id())
.layer(cors_layer)
// Compress responses based on whitelist configuration
// Only compresses when enabled and matches configured extensions/MIME types
.layer(CompressionLayer::new().compress_when(CompressionPredicate::new(compression_config)))
// Conditional CORS layer: only applies to S3 API requests (not Admin, not Console)
// Admin has its own CORS handling in router.rs
// Console has its own CORS layer in setup_console_middleware_stack()
// S3 API uses this system default CORS (RUSTFS_CORS_ALLOWED_ORIGINS)
// Bucket-level CORS takes precedence when configured (handled in router.rs for OPTIONS, and in ecfs.rs for actual requests)
.layer(ConditionalCorsLayer::new())
.option_layer(if is_console { Some(RedirectLayer) } else { None })
.service(service);

View File

@@ -12,14 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::admin::console::is_console_path;
use crate::server::cors;
use crate::server::hybrid::HybridBody;
use http::{Request as HttpRequest, Response, StatusCode};
use crate::server::{ADMIN_PREFIX, RPC_PREFIX};
use crate::storage::ecfs;
use http::{HeaderMap, HeaderValue, Method, Request as HttpRequest, Response, StatusCode};
use hyper::body::Incoming;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};
use tracing::debug;
use tracing::{debug, info};
/// Redirect layer that redirects browser requests to the console
#[derive(Clone)]
@@ -89,3 +94,173 @@ where
Box::pin(async move { inner.call(req).await.map_err(Into::into) })
}
}
/// Conditional CORS layer that only applies to S3 API requests
/// (not Admin, not Console, not RPC)
#[derive(Clone)]
pub struct ConditionalCorsLayer {
cors_origins: Option<String>,
}
impl ConditionalCorsLayer {
pub fn new() -> Self {
let cors_origins = std::env::var("RUSTFS_CORS_ALLOWED_ORIGINS").ok().filter(|s| !s.is_empty());
Self { cors_origins }
}
/// Exact paths that should be excluded from being treated as S3 paths.
const EXCLUDED_EXACT_PATHS: &'static [&'static str] = &["/health", "/profile/cpu", "/profile/memory"];
fn is_s3_path(path: &str) -> bool {
// Exclude Admin, Console, RPC, and configured special paths
!path.starts_with(ADMIN_PREFIX)
&& !path.starts_with(RPC_PREFIX)
&& !is_console_path(path)
&& !Self::EXCLUDED_EXACT_PATHS.contains(&path)
}
fn apply_cors_headers(&self, request_headers: &HeaderMap, response_headers: &mut HeaderMap) {
let origin = request_headers
.get(cors::standard::ORIGIN)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let allowed_origin = match (origin, &self.cors_origins) {
(Some(orig), Some(config)) if config == "*" => Some(orig),
(Some(orig), Some(config)) => {
let origins: Vec<&str> = config.split(',').map(|s| s.trim()).collect();
if origins.contains(&orig.as_str()) { Some(orig) } else { None }
}
(Some(orig), None) => Some(orig), // Default: allow all if not configured
_ => None,
};
// Track whether we're using a specific origin (not wildcard)
let using_specific_origin = if let Some(origin) = &allowed_origin {
if let Ok(header_value) = HeaderValue::from_str(origin) {
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, header_value);
true // Using specific origin, credentials allowed
} else {
false
}
} else {
false
};
// Allow all methods by default (S3-compatible set)
response_headers.insert(
cors::response::ACCESS_CONTROL_ALLOW_METHODS,
HeaderValue::from_static("GET, POST, PUT, DELETE, OPTIONS, HEAD"),
);
// Allow all headers by default
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_static("*"));
// Expose common headers
response_headers.insert(
cors::response::ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::from_static("x-request-id, content-type, content-length, etag"),
);
// Only set credentials when using a specific origin (not wildcard)
// CORS spec: credentials cannot be used with wildcard origins
if using_specific_origin {
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
}
}
}
impl Default for ConditionalCorsLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for ConditionalCorsLayer {
type Service = ConditionalCorsService<S>;
fn layer(&self, inner: S) -> Self::Service {
ConditionalCorsService {
inner,
cors_origins: Arc::new(self.cors_origins.clone()),
}
}
}
/// Service implementation for conditional CORS
#[derive(Clone)]
pub struct ConditionalCorsService<S> {
inner: S,
cors_origins: Arc<Option<String>>,
}
impl<S, ResBody> Service<HttpRequest<Incoming>> for ConditionalCorsService<S>
where
S: Service<HttpRequest<Incoming>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
ResBody: Default + Send + 'static,
{
type Response = Response<ResBody>;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: HttpRequest<Incoming>) -> Self::Future {
let path = req.uri().path().to_string();
let method = req.method().clone();
let request_headers = req.headers().clone();
let cors_origins = self.cors_origins.clone();
// Handle OPTIONS preflight requests - return response directly without calling handler
if method == Method::OPTIONS && request_headers.contains_key(cors::standard::ORIGIN) {
info!("OPTIONS preflight request for path: {}", path);
let path_trimmed = path.trim_start_matches('/');
let bucket = path_trimmed.split('/').next().unwrap_or("").to_string(); // virtual host style?
let method_clone = method.clone();
let request_headers_clone = request_headers.clone();
return Box::pin(async move {
let mut response = Response::builder().status(StatusCode::OK).body(ResBody::default()).unwrap();
if ConditionalCorsLayer::is_s3_path(&path)
&& !bucket.is_empty()
&& cors_origins.is_some()
&& let Some(cors_headers) = ecfs::apply_cors_headers(&bucket, &method_clone, &request_headers_clone).await
{
for (key, value) in cors_headers.iter() {
response.headers_mut().insert(key, value.clone());
}
return Ok(response);
}
let cors_layer = ConditionalCorsLayer {
cors_origins: (*cors_origins).clone(),
};
cors_layer.apply_cors_headers(&request_headers_clone, response.headers_mut());
Ok(response)
});
}
let mut inner = self.inner.clone();
Box::pin(async move {
let mut response = inner.call(req).await.map_err(Into::into)?;
// Apply CORS headers only to S3 API requests (non-OPTIONS)
if request_headers.contains_key(cors::standard::ORIGIN)
&& !response.headers().contains_key(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN)
{
let cors_layer = ConditionalCorsLayer {
cors_origins: (*cors_origins).clone(),
};
cors_layer.apply_cors_headers(&request_headers, response.headers_mut());
}
Ok(response)
})
}
}

View File

@@ -15,6 +15,7 @@
mod audit;
mod cert;
mod compress;
pub mod cors;
mod event;
mod http;
mod hybrid;

View File

@@ -342,7 +342,7 @@ impl S3Access for FS {
let req_info = req.extensions.get_mut::<ReqInfo>().expect("ReqInfo not found");
req_info.bucket = Some(req.input.bucket.clone());
authorize_request(req, Action::S3Action(S3Action::PutBucketCorsAction)).await
authorize_request(req, Action::S3Action(S3Action::DeleteBucketCorsAction)).await
}
/// Checks whether the DeleteBucketEncryption request has accesses to the resources.

View File

@@ -18,6 +18,7 @@ use crate::config::workload_profiles::{
};
use crate::error::ApiError;
use crate::server::RemoteAddr;
use crate::server::cors;
use crate::storage::concurrency::{
CachedGetObject, ConcurrencyManager, GetObjectGuard, get_concurrency_aware_buffer_size, get_concurrency_manager,
};
@@ -48,8 +49,8 @@ use rustfs_ecstore::{
lifecycle::{self, Lifecycle, TransitionOptions},
},
metadata::{
BUCKET_LIFECYCLE_CONFIG, BUCKET_NOTIFICATION_CONFIG, BUCKET_POLICY_CONFIG, BUCKET_REPLICATION_CONFIG,
BUCKET_SSECONFIG, BUCKET_TAGGING_CONFIG, BUCKET_VERSIONING_CONFIG, OBJECT_LOCK_CONFIG,
BUCKET_CORS_CONFIG, BUCKET_LIFECYCLE_CONFIG, BUCKET_NOTIFICATION_CONFIG, BUCKET_POLICY_CONFIG,
BUCKET_REPLICATION_CONFIG, BUCKET_SSECONFIG, BUCKET_TAGGING_CONFIG, BUCKET_VERSIONING_CONFIG, OBJECT_LOCK_CONFIG,
},
metadata_sys,
metadata_sys::get_replication_config,
@@ -116,10 +117,9 @@ use rustfs_utils::{
AMZ_BUCKET_REPLICATION_STATUS, AMZ_CHECKSUM_MODE, AMZ_CHECKSUM_TYPE,
headers::{
AMZ_DECODED_CONTENT_LENGTH, AMZ_OBJECT_TAGGING, AMZ_RESTORE_EXPIRY_DAYS, AMZ_RESTORE_REQUEST_DATE,
RESERVED_METADATA_PREFIX_LOWER,
RESERVED_METADATA_PREFIX, RESERVED_METADATA_PREFIX_LOWER,
},
},
obj::extract_user_defined_metadata,
path::{is_dir_object, path_join_buf},
};
use rustfs_zip::CompressionFormat;
@@ -818,6 +818,205 @@ async fn get_validated_store(bucket: &str) -> S3Result<Arc<rustfs_ecstore::store
Ok(store)
}
/// Quick check if CORS processing is needed (lightweight check for Origin header)
/// This avoids unnecessary function calls for non-CORS requests
#[inline]
fn needs_cors_processing(headers: &HeaderMap) -> bool {
headers.contains_key(cors::standard::ORIGIN)
}
/// Apply CORS headers to response based on bucket CORS configuration and request origin
///
/// This function:
/// 1. Reads the Origin header from the request
/// 2. Retrieves the bucket's CORS configuration
/// 3. Matches the origin against CORS rules
/// 4. Validates AllowedHeaders if request headers are present
/// 5. Returns headers to add to the response if a match is found
///
/// Note: This function should only be called if `needs_cors_processing()` returns true
/// to avoid unnecessary overhead for non-CORS requests.
pub(crate) async fn apply_cors_headers(bucket: &str, method: &http::Method, headers: &HeaderMap) -> Option<HeaderMap> {
use http::HeaderValue;
// Get Origin header from request
let origin = headers.get(cors::standard::ORIGIN)?.to_str().ok()?;
// Get CORS configuration for the bucket
let cors_config = match metadata_sys::get_cors_config(bucket).await {
Ok((config, _)) => config,
Err(_) => return None, // No CORS config, no headers to add
};
// Early return if no CORS rules configured
if cors_config.cors_rules.is_empty() {
return None;
}
// Check if method is supported and get its string representation
const SUPPORTED_METHODS: &[&str] = &["GET", "PUT", "POST", "DELETE", "HEAD", "OPTIONS"];
let method_str = method.as_str();
if !SUPPORTED_METHODS.contains(&method_str) {
return None;
}
// For OPTIONS (preflight) requests, check Access-Control-Request-Method
let is_preflight = method == http::Method::OPTIONS;
let requested_method = if is_preflight {
headers
.get(cors::request::ACCESS_CONTROL_REQUEST_METHOD)
.and_then(|v| v.to_str().ok())
.unwrap_or(method_str)
} else {
method_str
};
// Get requested headers from preflight request
let requested_headers = if is_preflight {
headers
.get(cors::request::ACCESS_CONTROL_REQUEST_HEADERS)
.and_then(|v| v.to_str().ok())
.map(|h| h.split(',').map(|s| s.trim().to_lowercase()).collect::<Vec<_>>())
} else {
None
};
// Find matching CORS rule
for rule in cors_config.cors_rules.iter() {
// Check if origin matches
let origin_matches = rule.allowed_origins.iter().any(|allowed_origin| {
if allowed_origin == "*" {
true
} else {
// Exact match or pattern match (support wildcards like https://*.example.com)
allowed_origin == origin || matches_origin_pattern(allowed_origin, origin)
}
});
if !origin_matches {
continue;
}
// Check if method is allowed
let method_allowed = rule
.allowed_methods
.iter()
.any(|allowed_method| allowed_method.as_str() == requested_method);
if !method_allowed {
continue;
}
// Validate AllowedHeaders if present in the request
if let Some(ref req_headers) = requested_headers {
if let Some(ref allowed_headers) = rule.allowed_headers {
// Check if all requested headers are allowed
let all_headers_allowed = req_headers.iter().all(|req_header| {
allowed_headers.iter().any(|allowed_header| {
let allowed_lower = allowed_header.to_lowercase();
// "*" allows all headers, or exact match
allowed_lower == "*" || allowed_lower == *req_header
})
});
if !all_headers_allowed {
// If not all headers are allowed, skip this rule
continue;
}
} else if !req_headers.is_empty() {
// If no AllowedHeaders specified but headers were requested, skip this rule
// Unless the rule explicitly allows all headers
continue;
}
}
// Found matching rule, build response headers
let mut response_headers = HeaderMap::new();
// Access-Control-Allow-Origin
// If origin is "*", use "*", otherwise echo back the origin
let has_wildcard_origin = rule.allowed_origins.iter().any(|o| o == "*");
if has_wildcard_origin {
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
} else if let Ok(origin_value) = HeaderValue::from_str(origin) {
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, origin_value);
}
// Vary: Origin (required for caching, except when using wildcard)
if !has_wildcard_origin {
response_headers.insert(cors::standard::VARY, HeaderValue::from_static("Origin"));
}
// Access-Control-Allow-Methods (required for preflight)
if is_preflight || !rule.allowed_methods.is_empty() {
let methods_str = rule.allowed_methods.iter().map(|m| m.as_str()).collect::<Vec<_>>().join(", ");
if let Ok(methods_value) = HeaderValue::from_str(&methods_str) {
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_METHODS, methods_value);
}
}
// Access-Control-Allow-Headers (required for preflight if headers were requested)
if is_preflight && let Some(ref allowed_headers) = rule.allowed_headers {
let headers_str = allowed_headers.iter().map(|h| h.as_str()).collect::<Vec<_>>().join(", ");
if let Ok(headers_value) = HeaderValue::from_str(&headers_str) {
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_HEADERS, headers_value);
}
}
// Access-Control-Expose-Headers (for actual requests)
if !is_preflight && let Some(ref expose_headers) = rule.expose_headers {
let expose_headers_str = expose_headers.iter().map(|h| h.as_str()).collect::<Vec<_>>().join(", ");
if let Ok(expose_value) = HeaderValue::from_str(&expose_headers_str) {
response_headers.insert(cors::response::ACCESS_CONTROL_EXPOSE_HEADERS, expose_value);
}
}
// Access-Control-Max-Age (for preflight requests)
if is_preflight
&& let Some(max_age) = rule.max_age_seconds
&& let Ok(max_age_value) = HeaderValue::from_str(&max_age.to_string())
{
response_headers.insert(cors::response::ACCESS_CONTROL_MAX_AGE, max_age_value);
}
return Some(response_headers);
}
None // No matching rule found
}
/// Check if an origin matches a pattern (supports wildcards like https://*.example.com)
fn matches_origin_pattern(pattern: &str, origin: &str) -> bool {
// Simple wildcard matching: * matches any sequence
if pattern.contains('*') {
let pattern_parts: Vec<&str> = pattern.split('*').collect();
if pattern_parts.len() == 2 {
origin.starts_with(pattern_parts[0]) && origin.ends_with(pattern_parts[1])
} else {
false
}
} else {
pattern == origin
}
}
/// Wrap S3Response with CORS headers if needed
/// This function performs a lightweight check first to avoid unnecessary CORS processing
/// for non-CORS requests (requests without Origin header)
async fn wrap_response_with_cors<T>(bucket: &str, method: &http::Method, headers: &HeaderMap, output: T) -> S3Response<T> {
let mut response = S3Response::new(output);
// Quick check: only process CORS if Origin header is present
if needs_cors_processing(headers)
&& let Some(cors_headers) = apply_cors_headers(bucket, method, headers).await
{
for (key, value) in cors_headers.iter() {
response.headers.insert(key, value.clone());
}
}
response
}
#[async_trait::async_trait]
impl S3 for FS {
#[instrument(
@@ -875,6 +1074,7 @@ impl S3 for FS {
metadata,
copy_source_if_match,
copy_source_if_none_match,
content_type,
..
} = req.input.clone();
let (src_bucket, src_key, version_id) = match copy_source {
@@ -890,6 +1090,19 @@ impl S3 for FS {
validate_object_key(&src_key, "COPY (source)")?;
validate_object_key(&key, "COPY (dest)")?;
// AWS S3 allows self-copy when metadata directive is REPLACE (used to update metadata in-place).
// Reject only when the directive is not REPLACE.
if metadata_directive.as_ref().map(|d| d.as_str()) != Some(MetadataDirective::REPLACE)
&& src_bucket == bucket
&& src_key == key
{
error!("Rejected self-copy operation: bucket={}, key={}", bucket, key);
return Err(s3_error!(
InvalidRequest,
"Cannot copy an object to itself. Source and destination must be different."
));
}
// warn!("copy_object {}/{}, to {}/{}", &src_bucket, &src_key, &bucket, &key);
let mut src_opts = copy_src_opts(&src_bucket, &src_key, &req.headers).map_err(ApiError::from)?;
@@ -1014,12 +1227,35 @@ impl S3 for FS {
src_info
.user_defined
.remove(&format!("{RESERVED_METADATA_PREFIX_LOWER}compression"));
src_info
.user_defined
.remove(&format!("{RESERVED_METADATA_PREFIX}compression"));
src_info
.user_defined
.remove(&format!("{RESERVED_METADATA_PREFIX_LOWER}actual-size"));
src_info
.user_defined
.remove(&format!("{RESERVED_METADATA_PREFIX}actual-size"));
src_info
.user_defined
.remove(&format!("{RESERVED_METADATA_PREFIX_LOWER}compression-size"));
src_info
.user_defined
.remove(&format!("{RESERVED_METADATA_PREFIX}compression-size"));
}
// Handle MetadataDirective REPLACE: replace user metadata while preserving system metadata.
// System metadata (compression, encryption) is added after this block to ensure
// it's not cleared by the REPLACE operation.
if metadata_directive.as_ref().map(|d| d.as_str()) == Some(MetadataDirective::REPLACE) {
src_info.user_defined.clear();
if let Some(metadata) = metadata {
src_info.user_defined.extend(metadata);
}
if let Some(ct) = content_type {
src_info.content_type = Some(ct.clone());
src_info.user_defined.insert("content-type".to_string(), ct);
}
}
let mut reader = HashReader::new(reader, length, actual_size, None, None, false).map_err(ApiError::from)?;
@@ -1104,16 +1340,6 @@ impl S3 for FS {
.insert("x-amz-server-side-encryption-customer-key-md5".to_string(), sse_md5.clone());
}
if metadata_directive.as_ref().map(|d| d.as_str()) == Some(MetadataDirective::REPLACE) {
let src_user_defined = extract_user_defined_metadata(&src_info.user_defined);
src_user_defined.keys().for_each(|k| {
src_info.user_defined.remove(k);
});
if let Some(metadata) = metadata {
src_info.user_defined.extend(metadata);
}
}
// check quota for copy operation
if let Some(metadata_sys) = rustfs_ecstore::bucket::metadata_sys::GLOBAL_BucketMetadataSys.get() {
let quota_checker = QuotaChecker::new(metadata_sys.clone());
@@ -2598,7 +2824,8 @@ impl S3 for FS {
cache_key, response_content_length, total_duration, optimal_buffer_size
);
let result = Ok(S3Response::new(output));
let response = wrap_response_with_cors(&bucket, &req.method, &req.headers, output).await;
let result = Ok(response);
let _ = helper.complete(&result);
result
}
@@ -2843,7 +3070,14 @@ impl S3 for FS {
let version_id = req.input.version_id.clone().unwrap_or_default();
helper = helper.object(event_info).version_id(version_id);
let result = Ok(S3Response::new(output));
// NOTE ON CORS:
// Bucket-level CORS headers are intentionally applied only for object retrieval
// operations (GET/HEAD) via `wrap_response_with_cors`. Other S3 operations that
// interact with objects (PUT/POST/DELETE/LIST, etc.) rely on the system-level
// CORS layer instead. In case both are applicable, this bucket-level CORS logic
// takes precedence for these read operations.
let response = wrap_response_with_cors(&bucket, &req.method, &req.headers, output).await;
let result = Ok(response);
let _ = helper.complete(&result);
result
@@ -4713,6 +4947,82 @@ impl S3 for FS {
Ok(S3Response::new(DeleteBucketTaggingOutput {}))
}
#[instrument(level = "debug", skip(self))]
async fn get_bucket_cors(&self, req: S3Request<GetBucketCorsInput>) -> S3Result<S3Response<GetBucketCorsOutput>> {
let bucket = req.input.bucket.clone();
// check bucket exists.
let _bucket = self
.head_bucket(req.map_input(|input| HeadBucketInput {
bucket: input.bucket,
expected_bucket_owner: None,
}))
.await?;
let cors_configuration = match metadata_sys::get_cors_config(&bucket).await {
Ok((config, _)) => config,
Err(err) => {
if err == StorageError::ConfigNotFound {
return Err(S3Error::with_message(
S3ErrorCode::NoSuchCORSConfiguration,
"The CORS configuration does not exist".to_string(),
));
}
warn!("get_cors_config err {:?}", &err);
return Err(ApiError::from(err).into());
}
};
Ok(S3Response::new(GetBucketCorsOutput {
cors_rules: Some(cors_configuration.cors_rules),
}))
}
#[instrument(level = "debug", skip(self))]
async fn put_bucket_cors(&self, req: S3Request<PutBucketCorsInput>) -> S3Result<S3Response<PutBucketCorsOutput>> {
let PutBucketCorsInput {
bucket,
cors_configuration,
..
} = req.input;
let Some(store) = new_object_layer_fn() else {
return Err(S3Error::with_message(S3ErrorCode::InternalError, "Not init".to_string()));
};
store
.get_bucket_info(&bucket, &BucketOptions::default())
.await
.map_err(ApiError::from)?;
let data = try_!(serialize(&cors_configuration));
metadata_sys::update(&bucket, BUCKET_CORS_CONFIG, data)
.await
.map_err(ApiError::from)?;
Ok(S3Response::new(PutBucketCorsOutput::default()))
}
#[instrument(level = "debug", skip(self))]
async fn delete_bucket_cors(&self, req: S3Request<DeleteBucketCorsInput>) -> S3Result<S3Response<DeleteBucketCorsOutput>> {
let DeleteBucketCorsInput { bucket, .. } = req.input;
let Some(store) = new_object_layer_fn() else {
return Err(S3Error::with_message(S3ErrorCode::InternalError, "Not init".to_string()));
};
store
.get_bucket_info(&bucket, &BucketOptions::default())
.await
.map_err(ApiError::from)?;
metadata_sys::delete(&bucket, BUCKET_CORS_CONFIG)
.await
.map_err(ApiError::from)?;
Ok(S3Response::new(DeleteBucketCorsOutput {}))
}
#[instrument(level = "debug", skip(self, req))]
async fn put_object_tagging(&self, req: S3Request<PutObjectTaggingInput>) -> S3Result<S3Response<PutObjectTaggingOutput>> {
let start_time = std::time::Instant::now();
@@ -6557,4 +6867,201 @@ mod tests {
assert!(filtered_version_marker.is_some());
assert_eq!(filtered_version_marker.unwrap(), "null");
}
#[test]
fn test_matches_origin_pattern_exact_match() {
// Test exact match
assert!(matches_origin_pattern("https://example.com", "https://example.com"));
assert!(matches_origin_pattern("http://localhost:3000", "http://localhost:3000"));
assert!(!matches_origin_pattern("https://example.com", "https://other.com"));
}
#[test]
fn test_matches_origin_pattern_wildcard() {
// Test wildcard pattern matching (S3 CORS supports * as subdomain wildcard)
assert!(matches_origin_pattern("https://*.example.com", "https://app.example.com"));
assert!(matches_origin_pattern("https://*.example.com", "https://api.example.com"));
assert!(matches_origin_pattern("https://*.example.com", "https://subdomain.example.com"));
// Test wildcard at start (matches any domain)
assert!(matches_origin_pattern("https://*", "https://example.com"));
assert!(matches_origin_pattern("https://*", "https://any-domain.com"));
// Test wildcard at end (matches any protocol)
assert!(matches_origin_pattern("*://example.com", "https://example.com"));
assert!(matches_origin_pattern("*://example.com", "http://example.com"));
// Test invalid wildcard patterns (should not match)
assert!(!matches_origin_pattern("https://*.*.com", "https://app.example.com")); // Multiple wildcards (invalid pattern)
// Note: "https://*example.com" actually matches "https://app.example.com" with our current implementation
// because it splits on * and checks starts_with/ends_with. This is a limitation but acceptable
// for S3 CORS which typically uses patterns like "https://*.example.com"
}
#[test]
fn test_matches_origin_pattern_no_wildcard() {
// Test patterns without wildcards
assert!(matches_origin_pattern("https://example.com", "https://example.com"));
assert!(!matches_origin_pattern("https://example.com", "https://example.org"));
assert!(!matches_origin_pattern("http://example.com", "https://example.com")); // Different protocol
}
#[test]
fn test_matches_origin_pattern_edge_cases() {
// Test edge cases
assert!(!matches_origin_pattern("", "https://example.com")); // Empty pattern
assert!(!matches_origin_pattern("https://example.com", "")); // Empty origin
assert!(matches_origin_pattern("", "")); // Both empty
assert!(!matches_origin_pattern("https://example.com", "http://example.com")); // Protocol mismatch
}
#[test]
fn test_cors_headers_validation() {
use http::HeaderMap;
// Test case 1: Validate header name case-insensitivity
let mut headers = HeaderMap::new();
headers.insert("access-control-request-headers", "Content-Type,X-Custom-Header".parse().unwrap());
let req_headers_str = headers
.get("access-control-request-headers")
.and_then(|v| v.to_str().ok())
.unwrap();
let req_headers: Vec<String> = req_headers_str.split(',').map(|s| s.trim().to_lowercase()).collect();
// Headers should be lowercased for comparison
assert_eq!(req_headers, vec!["content-type", "x-custom-header"]);
// Test case 2: Wildcard matching
let allowed_headers = ["*".to_string()];
let all_allowed = req_headers.iter().all(|req_header| {
allowed_headers
.iter()
.any(|allowed| allowed.to_lowercase() == "*" || allowed.to_lowercase() == *req_header)
});
assert!(all_allowed, "Wildcard should allow all headers");
// Test case 3: Specific header matching
let allowed_headers = ["content-type".to_string(), "x-custom-header".to_string()];
let all_allowed = req_headers
.iter()
.all(|req_header| allowed_headers.iter().any(|allowed| allowed.to_lowercase() == *req_header));
assert!(all_allowed, "All requested headers should be allowed");
// Test case 4: Disallowed header
let req_headers = ["content-type".to_string(), "x-forbidden-header".to_string()];
let allowed_headers = ["content-type".to_string()];
let all_allowed = req_headers
.iter()
.all(|req_header| allowed_headers.iter().any(|allowed| allowed.to_lowercase() == *req_header));
assert!(!all_allowed, "Forbidden header should not be allowed");
}
#[test]
fn test_cors_response_headers_structure() {
use http::{HeaderMap, HeaderValue};
let mut cors_headers = HeaderMap::new();
// Simulate building CORS response headers
let origin = "https://example.com";
let methods = ["GET", "PUT", "POST"];
let allowed_headers = ["Content-Type", "Authorization"];
let expose_headers = ["ETag", "x-amz-version-id"];
let max_age = 3600;
// Add headers
cors_headers.insert("access-control-allow-origin", HeaderValue::from_str(origin).unwrap());
cors_headers.insert("vary", HeaderValue::from_static("Origin"));
let methods_str = methods.join(", ");
cors_headers.insert("access-control-allow-methods", HeaderValue::from_str(&methods_str).unwrap());
let headers_str = allowed_headers.join(", ");
cors_headers.insert("access-control-allow-headers", HeaderValue::from_str(&headers_str).unwrap());
let expose_str = expose_headers.join(", ");
cors_headers.insert("access-control-expose-headers", HeaderValue::from_str(&expose_str).unwrap());
cors_headers.insert("access-control-max-age", HeaderValue::from_str(&max_age.to_string()).unwrap());
// Verify all headers are present
assert_eq!(cors_headers.get("access-control-allow-origin").unwrap(), origin);
assert_eq!(cors_headers.get("vary").unwrap(), "Origin");
assert_eq!(cors_headers.get("access-control-allow-methods").unwrap(), "GET, PUT, POST");
assert_eq!(cors_headers.get("access-control-allow-headers").unwrap(), "Content-Type, Authorization");
assert_eq!(cors_headers.get("access-control-expose-headers").unwrap(), "ETag, x-amz-version-id");
assert_eq!(cors_headers.get("access-control-max-age").unwrap(), "3600");
}
#[test]
fn test_cors_preflight_vs_actual_request() {
use http::Method;
// Test that we can distinguish preflight from actual requests
let preflight_method = Method::OPTIONS;
let actual_method = Method::PUT;
assert_eq!(preflight_method, Method::OPTIONS);
assert_ne!(actual_method, Method::OPTIONS);
// Preflight should check Access-Control-Request-Method
// Actual request should use the actual method
let is_preflight_1 = preflight_method == Method::OPTIONS;
let is_preflight_2 = actual_method == Method::OPTIONS;
assert!(is_preflight_1);
assert!(!is_preflight_2);
}
#[tokio::test]
async fn test_apply_cors_headers_no_origin() {
// Test when no Origin header is present
let headers = HeaderMap::new();
let method = http::Method::GET;
// Should return None when no origin header
let result = apply_cors_headers("test-bucket", &method, &headers).await;
assert!(result.is_none(), "Should return None when no Origin header");
}
#[tokio::test]
async fn test_apply_cors_headers_no_cors_config() {
// Test when bucket has no CORS configuration
let mut headers = HeaderMap::new();
headers.insert("origin", "https://example.com".parse().unwrap());
let method = http::Method::GET;
// Should return None when no CORS config exists
// Note: This test may fail if test-bucket actually has CORS config
// In a real scenario, we'd use a mock or ensure the bucket doesn't exist
let _result = apply_cors_headers("non-existent-bucket-for-testing", &method, &headers).await;
// Result depends on whether bucket exists and has CORS config
// This is expected behavior - we just verify it doesn't panic
}
#[tokio::test]
async fn test_apply_cors_headers_unsupported_method() {
// Test with unsupported HTTP method
let mut headers = HeaderMap::new();
headers.insert("origin", "https://example.com".parse().unwrap());
let method = http::Method::PATCH; // Unsupported method
let result = apply_cors_headers("test-bucket", &method, &headers).await;
assert!(result.is_none(), "Should return None for unsupported methods");
}
#[test]
fn test_matches_origin_pattern_complex_wildcards() {
// Test more complex wildcard scenarios
assert!(matches_origin_pattern("https://*.example.com", "https://sub.example.com"));
// Note: "https://*.example.com" matches "https://api.sub.example.com" with our implementation
// because it only checks starts_with and ends_with. Real S3 might be more strict.
// Test wildcard in middle position
// Our implementation allows this, but it's not standard S3 CORS pattern
// The pattern "https://example.*.com" splits to ["https://example.", ".com"]
// and "https://example.sub.com" matches because it starts with "https://example." and ends with ".com"
// This is acceptable for our use case as S3 CORS typically uses "https://*.example.com" format
}
}