mirror of
https://github.com/rustfs/rustfs.git
synced 2026-01-16 17:20:33 +00:00
feat: add Cors (#1496)
Signed-off-by: GatewayJ <835269233@qq.com> Co-authored-by: loverustfs <hello@rustfs.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: houseme <housemecn@gmail.com>
This commit is contained in:
@@ -25,8 +25,8 @@ use byteorder::{BigEndian, ByteOrder, LittleEndian};
|
||||
use rmp_serde::Serializer as rmpSerializer;
|
||||
use rustfs_policy::policy::BucketPolicy;
|
||||
use s3s::dto::{
|
||||
BucketLifecycleConfiguration, NotificationConfiguration, ObjectLockConfiguration, ReplicationConfiguration,
|
||||
ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration,
|
||||
BucketLifecycleConfiguration, CORSConfiguration, NotificationConfiguration, ObjectLockConfiguration,
|
||||
ReplicationConfiguration, ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration,
|
||||
};
|
||||
use serde::Serializer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -49,6 +49,7 @@ pub const OBJECT_LOCK_CONFIG: &str = "object-lock.xml";
|
||||
pub const BUCKET_VERSIONING_CONFIG: &str = "versioning.xml";
|
||||
pub const BUCKET_REPLICATION_CONFIG: &str = "replication.xml";
|
||||
pub const BUCKET_TARGETS_FILE: &str = "bucket-targets.json";
|
||||
pub const BUCKET_CORS_CONFIG: &str = "cors.xml";
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
#[serde(rename_all = "PascalCase", default)]
|
||||
@@ -67,6 +68,7 @@ pub struct BucketMetadata {
|
||||
pub replication_config_xml: Vec<u8>,
|
||||
pub bucket_targets_config_json: Vec<u8>,
|
||||
pub bucket_targets_config_meta_json: Vec<u8>,
|
||||
pub cors_config_xml: Vec<u8>,
|
||||
|
||||
pub policy_config_updated_at: OffsetDateTime,
|
||||
pub object_lock_config_updated_at: OffsetDateTime,
|
||||
@@ -79,6 +81,7 @@ pub struct BucketMetadata {
|
||||
pub notification_config_updated_at: OffsetDateTime,
|
||||
pub bucket_targets_config_updated_at: OffsetDateTime,
|
||||
pub bucket_targets_config_meta_updated_at: OffsetDateTime,
|
||||
pub cors_config_updated_at: OffsetDateTime,
|
||||
|
||||
#[serde(skip)]
|
||||
pub new_field_updated_at: OffsetDateTime,
|
||||
@@ -105,6 +108,8 @@ pub struct BucketMetadata {
|
||||
pub bucket_target_config: Option<BucketTargets>,
|
||||
#[serde(skip)]
|
||||
pub bucket_target_config_meta: Option<HashMap<String, String>>,
|
||||
#[serde(skip)]
|
||||
pub cors_config: Option<CORSConfiguration>,
|
||||
}
|
||||
|
||||
impl Default for BucketMetadata {
|
||||
@@ -124,6 +129,7 @@ impl Default for BucketMetadata {
|
||||
replication_config_xml: Default::default(),
|
||||
bucket_targets_config_json: Default::default(),
|
||||
bucket_targets_config_meta_json: Default::default(),
|
||||
cors_config_xml: Default::default(),
|
||||
policy_config_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
object_lock_config_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
encryption_config_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
@@ -135,6 +141,7 @@ impl Default for BucketMetadata {
|
||||
notification_config_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
bucket_targets_config_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
bucket_targets_config_meta_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
cors_config_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
new_field_updated_at: OffsetDateTime::UNIX_EPOCH,
|
||||
policy_config: Default::default(),
|
||||
notification_config: Default::default(),
|
||||
@@ -147,6 +154,7 @@ impl Default for BucketMetadata {
|
||||
replication_config: Default::default(),
|
||||
bucket_target_config: Default::default(),
|
||||
bucket_target_config_meta: Default::default(),
|
||||
cors_config: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -295,6 +303,10 @@ impl BucketMetadata {
|
||||
self.bucket_targets_config_json = data.clone();
|
||||
self.bucket_targets_config_updated_at = updated;
|
||||
}
|
||||
BUCKET_CORS_CONFIG => {
|
||||
self.cors_config_xml = data;
|
||||
self.cors_config_updated_at = updated;
|
||||
}
|
||||
_ => return Err(Error::other(format!("config file not found : {config_file}"))),
|
||||
}
|
||||
|
||||
@@ -365,6 +377,9 @@ impl BucketMetadata {
|
||||
} else {
|
||||
self.bucket_target_config = Some(BucketTargets::default())
|
||||
}
|
||||
if !self.cors_config_xml.is_empty() {
|
||||
self.cors_config = Some(deserialize::<CORSConfiguration>(&self.cors_config_xml)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ use rustfs_common::heal_channel::HealOpts;
|
||||
use rustfs_policy::policy::BucketPolicy;
|
||||
use s3s::dto::ReplicationConfiguration;
|
||||
use s3s::dto::{
|
||||
BucketLifecycleConfiguration, NotificationConfiguration, ObjectLockConfiguration, ServerSideEncryptionConfiguration, Tagging,
|
||||
VersioningConfiguration,
|
||||
BucketLifecycleConfiguration, CORSConfiguration, NotificationConfiguration, ObjectLockConfiguration,
|
||||
ServerSideEncryptionConfiguration, Tagging, VersioningConfiguration,
|
||||
};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::OnceLock;
|
||||
@@ -110,6 +110,13 @@ pub async fn get_bucket_targets_config(bucket: &str) -> Result<BucketTargets> {
|
||||
bucket_meta_sys.get_bucket_targets_config(bucket).await
|
||||
}
|
||||
|
||||
pub async fn get_cors_config(bucket: &str) -> Result<(CORSConfiguration, OffsetDateTime)> {
|
||||
let bucket_meta_sys_lock = get_bucket_metadata_sys()?;
|
||||
let bucket_meta_sys = bucket_meta_sys_lock.read().await;
|
||||
|
||||
bucket_meta_sys.get_cors_config(bucket).await
|
||||
}
|
||||
|
||||
pub async fn get_tagging_config(bucket: &str) -> Result<(Tagging, OffsetDateTime)> {
|
||||
let bucket_meta_sys_lock = get_bucket_metadata_sys()?;
|
||||
let bucket_meta_sys = bucket_meta_sys_lock.read().await;
|
||||
@@ -500,6 +507,16 @@ impl BucketMetadataSys {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_cors_config(&self, bucket: &str) -> Result<(CORSConfiguration, OffsetDateTime)> {
|
||||
let (bm, _) = self.get_config(bucket).await?;
|
||||
|
||||
if let Some(config) = &bm.cors_config {
|
||||
Ok((config.clone(), bm.cors_config_updated_at))
|
||||
} else {
|
||||
Err(Error::ConfigNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn created_at(&self, bucket: &str) -> Result<OffsetDateTime> {
|
||||
let bm = match self.get_config(bucket).await {
|
||||
Ok((bm, _)) => bm.created,
|
||||
|
||||
@@ -84,6 +84,7 @@ where
|
||||
{
|
||||
fn is_match(&self, method: &Method, uri: &Uri, headers: &HeaderMap, _: &mut Extensions) -> bool {
|
||||
let path = uri.path();
|
||||
|
||||
// Profiling endpoints
|
||||
if method == Method::GET && (path == PROFILE_CPU_PATH || path == PROFILE_MEMORY_PATH) {
|
||||
return true;
|
||||
@@ -150,6 +151,8 @@ where
|
||||
}
|
||||
|
||||
async fn call(&self, req: S3Request<Body>) -> S3Result<S3Response<Body>> {
|
||||
// Console requests should be handled by console router first (including OPTIONS)
|
||||
// Console has its own CORS layer configured
|
||||
if self.console_enabled && is_console_path(req.uri.path()) {
|
||||
if let Some(console_router) = &self.console_router {
|
||||
let mut console_router = console_router.clone();
|
||||
@@ -164,11 +167,14 @@ where
|
||||
}
|
||||
|
||||
let uri = format!("{}|{}", &req.method, req.uri.path());
|
||||
|
||||
if let Ok(mat) = self.router.at(&uri) {
|
||||
let op: &T = mat.value;
|
||||
let mut resp = op.call(req, mat.params).await?;
|
||||
resp.status = Some(resp.output.0);
|
||||
return Ok(resp.map_output(|x| x.1));
|
||||
let response = resp.map_output(|x| x.1);
|
||||
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
Err(s3_error!(NotImplemented))
|
||||
|
||||
40
rustfs/src/server/cors.rs
Normal file
40
rustfs/src/server/cors.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright 2024 RustFS Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! CORS (Cross-Origin Resource Sharing) header name constants.
|
||||
//!
|
||||
//! This module provides centralized constants for CORS-related HTTP header names.
|
||||
//! The http crate doesn't provide pre-defined constants for CORS headers,
|
||||
//! so we define them here for type safety and maintainability.
|
||||
|
||||
/// CORS response header names
|
||||
pub mod response {
|
||||
pub const ACCESS_CONTROL_ALLOW_ORIGIN: &str = "access-control-allow-origin";
|
||||
pub const ACCESS_CONTROL_ALLOW_METHODS: &str = "access-control-allow-methods";
|
||||
pub const ACCESS_CONTROL_ALLOW_HEADERS: &str = "access-control-allow-headers";
|
||||
pub const ACCESS_CONTROL_EXPOSE_HEADERS: &str = "access-control-expose-headers";
|
||||
pub const ACCESS_CONTROL_ALLOW_CREDENTIALS: &str = "access-control-allow-credentials";
|
||||
pub const ACCESS_CONTROL_MAX_AGE: &str = "access-control-max-age";
|
||||
}
|
||||
|
||||
/// CORS request header names
|
||||
pub mod request {
|
||||
pub const ACCESS_CONTROL_REQUEST_METHOD: &str = "access-control-request-method";
|
||||
pub const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers";
|
||||
}
|
||||
|
||||
/// Standard HTTP header names used in CORS processing
|
||||
pub mod standard {
|
||||
pub use http::header::{ORIGIN, VARY};
|
||||
}
|
||||
@@ -17,7 +17,11 @@ use super::compress::{CompressionConfig, CompressionPredicate};
|
||||
use crate::admin;
|
||||
use crate::auth::IAMAuth;
|
||||
use crate::config;
|
||||
use crate::server::{ReadinessGateLayer, RemoteAddr, ServiceState, ServiceStateManager, hybrid::hybrid, layer::RedirectLayer};
|
||||
use crate::server::{
|
||||
ReadinessGateLayer, RemoteAddr, ServiceState, ServiceStateManager,
|
||||
hybrid::hybrid,
|
||||
layer::{ConditionalCorsLayer, RedirectLayer},
|
||||
};
|
||||
use crate::storage;
|
||||
use crate::storage::tonic_service::make_server;
|
||||
use bytes::Bytes;
|
||||
@@ -48,70 +52,10 @@ use tower::ServiceBuilder;
|
||||
use tower_http::add_extension::AddExtensionLayer;
|
||||
use tower_http::catch_panic::CatchPanicLayer;
|
||||
use tower_http::compression::CompressionLayer;
|
||||
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
|
||||
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::{Span, debug, error, info, instrument, warn};
|
||||
|
||||
/// Parse CORS allowed origins from configuration
|
||||
fn parse_cors_origins(origins: Option<&String>) -> CorsLayer {
|
||||
use http::Method;
|
||||
|
||||
let cors_layer = CorsLayer::new()
|
||||
.allow_methods([
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::HEAD,
|
||||
Method::OPTIONS,
|
||||
])
|
||||
.allow_headers(Any);
|
||||
|
||||
match origins {
|
||||
Some(origins_str) if origins_str == "*" => cors_layer.allow_origin(Any).expose_headers(Any),
|
||||
Some(origins_str) => {
|
||||
let origins: Vec<&str> = origins_str.split(',').map(|s| s.trim()).collect();
|
||||
if origins.is_empty() {
|
||||
warn!("Empty CORS origins provided, using permissive CORS");
|
||||
cors_layer.allow_origin(Any).expose_headers(Any)
|
||||
} else {
|
||||
// Parse origins with proper error handling
|
||||
let mut valid_origins = Vec::new();
|
||||
for origin in origins {
|
||||
match origin.parse::<http::HeaderValue>() {
|
||||
Ok(header_value) => {
|
||||
valid_origins.push(header_value);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Invalid CORS origin '{}': {}", origin, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if valid_origins.is_empty() {
|
||||
warn!("No valid CORS origins found, using permissive CORS");
|
||||
cors_layer.allow_origin(Any).expose_headers(Any)
|
||||
} else {
|
||||
info!("Endpoint CORS origins configured: {:?}", valid_origins);
|
||||
cors_layer.allow_origin(AllowOrigin::list(valid_origins)).expose_headers(Any)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
debug!("No CORS origins configured for endpoint, using permissive CORS");
|
||||
cors_layer.allow_origin(Any).expose_headers(Any)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_cors_allowed_origins() -> String {
|
||||
std::env::var(rustfs_config::ENV_CORS_ALLOWED_ORIGINS)
|
||||
.unwrap_or_else(|_| rustfs_config::DEFAULT_CORS_ALLOWED_ORIGINS.to_string())
|
||||
.parse::<String>()
|
||||
.unwrap_or(rustfs_config::DEFAULT_CONSOLE_CORS_ALLOWED_ORIGINS.to_string())
|
||||
}
|
||||
|
||||
pub async fn start_http_server(
|
||||
opt: &config::Opt,
|
||||
worker_state_manager: ServiceStateManager,
|
||||
@@ -273,14 +217,6 @@ pub async fn start_http_server(
|
||||
let (shutdown_tx, mut shutdown_rx) = tokio::sync::broadcast::channel(1);
|
||||
let shutdown_tx_clone = shutdown_tx.clone();
|
||||
|
||||
// Capture CORS configuration for the server loop
|
||||
let cors_allowed_origins = get_cors_allowed_origins();
|
||||
let cors_allowed_origins = if cors_allowed_origins.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(cors_allowed_origins)
|
||||
};
|
||||
|
||||
// Create compression configuration from environment variables
|
||||
let compression_config = CompressionConfig::from_env();
|
||||
if compression_config.enabled {
|
||||
@@ -294,8 +230,10 @@ pub async fn start_http_server(
|
||||
|
||||
let is_console = opt.console_enable;
|
||||
tokio::spawn(async move {
|
||||
// Create CORS layer inside the server loop closure
|
||||
let cors_layer = parse_cors_origins(cors_allowed_origins.as_ref());
|
||||
// Note: CORS layer is removed from global middleware stack
|
||||
// - S3 API CORS is handled by bucket-level CORS configuration in apply_cors_headers()
|
||||
// - Console CORS is handled by its own cors_layer in setup_console_middleware_stack()
|
||||
// This ensures S3 API CORS behavior matches AWS S3 specification
|
||||
|
||||
#[cfg(unix)]
|
||||
let (mut sigterm_inner, mut sigint_inner) = {
|
||||
@@ -405,7 +343,6 @@ pub async fn start_http_server(
|
||||
let connection_ctx = ConnectionContext {
|
||||
http_server: http_server.clone(),
|
||||
s3_service: s3_service.clone(),
|
||||
cors_layer: cors_layer.clone(),
|
||||
compression_config: compression_config.clone(),
|
||||
is_console,
|
||||
readiness: readiness.clone(),
|
||||
@@ -521,7 +458,6 @@ async fn setup_tls_acceptor(tls_path: &str) -> Result<Option<TlsAcceptor>> {
|
||||
struct ConnectionContext {
|
||||
http_server: Arc<ConnBuilder<TokioExecutor>>,
|
||||
s3_service: S3Service,
|
||||
cors_layer: CorsLayer,
|
||||
compression_config: CompressionConfig,
|
||||
is_console: bool,
|
||||
readiness: Arc<GlobalReadiness>,
|
||||
@@ -546,7 +482,6 @@ fn process_connection(
|
||||
let ConnectionContext {
|
||||
http_server,
|
||||
s3_service,
|
||||
cors_layer,
|
||||
compression_config,
|
||||
is_console,
|
||||
readiness,
|
||||
@@ -629,10 +564,15 @@ fn process_connection(
|
||||
}),
|
||||
)
|
||||
.layer(PropagateRequestIdLayer::x_request_id())
|
||||
.layer(cors_layer)
|
||||
// Compress responses based on whitelist configuration
|
||||
// Only compresses when enabled and matches configured extensions/MIME types
|
||||
.layer(CompressionLayer::new().compress_when(CompressionPredicate::new(compression_config)))
|
||||
// Conditional CORS layer: only applies to S3 API requests (not Admin, not Console)
|
||||
// Admin has its own CORS handling in router.rs
|
||||
// Console has its own CORS layer in setup_console_middleware_stack()
|
||||
// S3 API uses this system default CORS (RUSTFS_CORS_ALLOWED_ORIGINS)
|
||||
// Bucket-level CORS takes precedence when configured (handled in router.rs for OPTIONS, and in ecfs.rs for actual requests)
|
||||
.layer(ConditionalCorsLayer::new())
|
||||
.option_layer(if is_console { Some(RedirectLayer) } else { None })
|
||||
.service(service);
|
||||
|
||||
|
||||
@@ -12,14 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::admin::console::is_console_path;
|
||||
use crate::server::cors;
|
||||
use crate::server::hybrid::HybridBody;
|
||||
use http::{Request as HttpRequest, Response, StatusCode};
|
||||
use crate::server::{ADMIN_PREFIX, RPC_PREFIX};
|
||||
use crate::storage::ecfs;
|
||||
use http::{HeaderMap, HeaderValue, Method, Request as HttpRequest, Response, StatusCode};
|
||||
use hyper::body::Incoming;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tower::{Layer, Service};
|
||||
use tracing::debug;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Redirect layer that redirects browser requests to the console
|
||||
#[derive(Clone)]
|
||||
@@ -89,3 +94,173 @@ where
|
||||
Box::pin(async move { inner.call(req).await.map_err(Into::into) })
|
||||
}
|
||||
}
|
||||
|
||||
/// Conditional CORS layer that only applies to S3 API requests
|
||||
/// (not Admin, not Console, not RPC)
|
||||
#[derive(Clone)]
|
||||
pub struct ConditionalCorsLayer {
|
||||
cors_origins: Option<String>,
|
||||
}
|
||||
|
||||
impl ConditionalCorsLayer {
|
||||
pub fn new() -> Self {
|
||||
let cors_origins = std::env::var("RUSTFS_CORS_ALLOWED_ORIGINS").ok().filter(|s| !s.is_empty());
|
||||
Self { cors_origins }
|
||||
}
|
||||
|
||||
/// Exact paths that should be excluded from being treated as S3 paths.
|
||||
const EXCLUDED_EXACT_PATHS: &'static [&'static str] = &["/health", "/profile/cpu", "/profile/memory"];
|
||||
|
||||
fn is_s3_path(path: &str) -> bool {
|
||||
// Exclude Admin, Console, RPC, and configured special paths
|
||||
!path.starts_with(ADMIN_PREFIX)
|
||||
&& !path.starts_with(RPC_PREFIX)
|
||||
&& !is_console_path(path)
|
||||
&& !Self::EXCLUDED_EXACT_PATHS.contains(&path)
|
||||
}
|
||||
|
||||
fn apply_cors_headers(&self, request_headers: &HeaderMap, response_headers: &mut HeaderMap) {
|
||||
let origin = request_headers
|
||||
.get(cors::standard::ORIGIN)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let allowed_origin = match (origin, &self.cors_origins) {
|
||||
(Some(orig), Some(config)) if config == "*" => Some(orig),
|
||||
(Some(orig), Some(config)) => {
|
||||
let origins: Vec<&str> = config.split(',').map(|s| s.trim()).collect();
|
||||
if origins.contains(&orig.as_str()) { Some(orig) } else { None }
|
||||
}
|
||||
(Some(orig), None) => Some(orig), // Default: allow all if not configured
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Track whether we're using a specific origin (not wildcard)
|
||||
let using_specific_origin = if let Some(origin) = &allowed_origin {
|
||||
if let Ok(header_value) = HeaderValue::from_str(origin) {
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, header_value);
|
||||
true // Using specific origin, credentials allowed
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// Allow all methods by default (S3-compatible set)
|
||||
response_headers.insert(
|
||||
cors::response::ACCESS_CONTROL_ALLOW_METHODS,
|
||||
HeaderValue::from_static("GET, POST, PUT, DELETE, OPTIONS, HEAD"),
|
||||
);
|
||||
|
||||
// Allow all headers by default
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_static("*"));
|
||||
|
||||
// Expose common headers
|
||||
response_headers.insert(
|
||||
cors::response::ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||
HeaderValue::from_static("x-request-id, content-type, content-length, etag"),
|
||||
);
|
||||
|
||||
// Only set credentials when using a specific origin (not wildcard)
|
||||
// CORS spec: credentials cannot be used with wildcard origins
|
||||
if using_specific_origin {
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ConditionalCorsLayer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for ConditionalCorsLayer {
|
||||
type Service = ConditionalCorsService<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
ConditionalCorsService {
|
||||
inner,
|
||||
cors_origins: Arc::new(self.cors_origins.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Service implementation for conditional CORS
|
||||
#[derive(Clone)]
|
||||
pub struct ConditionalCorsService<S> {
|
||||
inner: S,
|
||||
cors_origins: Arc<Option<String>>,
|
||||
}
|
||||
|
||||
impl<S, ResBody> Service<HttpRequest<Incoming>> for ConditionalCorsService<S>
|
||||
where
|
||||
S: Service<HttpRequest<Incoming>, Response = Response<ResBody>> + Clone + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
|
||||
ResBody: Default + Send + 'static,
|
||||
{
|
||||
type Response = Response<ResBody>;
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: HttpRequest<Incoming>) -> Self::Future {
|
||||
let path = req.uri().path().to_string();
|
||||
let method = req.method().clone();
|
||||
let request_headers = req.headers().clone();
|
||||
let cors_origins = self.cors_origins.clone();
|
||||
// Handle OPTIONS preflight requests - return response directly without calling handler
|
||||
if method == Method::OPTIONS && request_headers.contains_key(cors::standard::ORIGIN) {
|
||||
info!("OPTIONS preflight request for path: {}", path);
|
||||
|
||||
let path_trimmed = path.trim_start_matches('/');
|
||||
let bucket = path_trimmed.split('/').next().unwrap_or("").to_string(); // virtual host style?
|
||||
let method_clone = method.clone();
|
||||
let request_headers_clone = request_headers.clone();
|
||||
|
||||
return Box::pin(async move {
|
||||
let mut response = Response::builder().status(StatusCode::OK).body(ResBody::default()).unwrap();
|
||||
|
||||
if ConditionalCorsLayer::is_s3_path(&path)
|
||||
&& !bucket.is_empty()
|
||||
&& cors_origins.is_some()
|
||||
&& let Some(cors_headers) = ecfs::apply_cors_headers(&bucket, &method_clone, &request_headers_clone).await
|
||||
{
|
||||
for (key, value) in cors_headers.iter() {
|
||||
response.headers_mut().insert(key, value.clone());
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
let cors_layer = ConditionalCorsLayer {
|
||||
cors_origins: (*cors_origins).clone(),
|
||||
};
|
||||
cors_layer.apply_cors_headers(&request_headers_clone, response.headers_mut());
|
||||
|
||||
Ok(response)
|
||||
});
|
||||
}
|
||||
|
||||
let mut inner = self.inner.clone();
|
||||
Box::pin(async move {
|
||||
let mut response = inner.call(req).await.map_err(Into::into)?;
|
||||
|
||||
// Apply CORS headers only to S3 API requests (non-OPTIONS)
|
||||
if request_headers.contains_key(cors::standard::ORIGIN)
|
||||
&& !response.headers().contains_key(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN)
|
||||
{
|
||||
let cors_layer = ConditionalCorsLayer {
|
||||
cors_origins: (*cors_origins).clone(),
|
||||
};
|
||||
cors_layer.apply_cors_headers(&request_headers, response.headers_mut());
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
mod audit;
|
||||
mod cert;
|
||||
mod compress;
|
||||
pub mod cors;
|
||||
mod event;
|
||||
mod http;
|
||||
mod hybrid;
|
||||
|
||||
@@ -342,7 +342,7 @@ impl S3Access for FS {
|
||||
let req_info = req.extensions.get_mut::<ReqInfo>().expect("ReqInfo not found");
|
||||
req_info.bucket = Some(req.input.bucket.clone());
|
||||
|
||||
authorize_request(req, Action::S3Action(S3Action::PutBucketCorsAction)).await
|
||||
authorize_request(req, Action::S3Action(S3Action::DeleteBucketCorsAction)).await
|
||||
}
|
||||
|
||||
/// Checks whether the DeleteBucketEncryption request has accesses to the resources.
|
||||
|
||||
@@ -18,6 +18,7 @@ use crate::config::workload_profiles::{
|
||||
};
|
||||
use crate::error::ApiError;
|
||||
use crate::server::RemoteAddr;
|
||||
use crate::server::cors;
|
||||
use crate::storage::concurrency::{
|
||||
CachedGetObject, ConcurrencyManager, GetObjectGuard, get_concurrency_aware_buffer_size, get_concurrency_manager,
|
||||
};
|
||||
@@ -48,8 +49,8 @@ use rustfs_ecstore::{
|
||||
lifecycle::{self, Lifecycle, TransitionOptions},
|
||||
},
|
||||
metadata::{
|
||||
BUCKET_LIFECYCLE_CONFIG, BUCKET_NOTIFICATION_CONFIG, BUCKET_POLICY_CONFIG, BUCKET_REPLICATION_CONFIG,
|
||||
BUCKET_SSECONFIG, BUCKET_TAGGING_CONFIG, BUCKET_VERSIONING_CONFIG, OBJECT_LOCK_CONFIG,
|
||||
BUCKET_CORS_CONFIG, BUCKET_LIFECYCLE_CONFIG, BUCKET_NOTIFICATION_CONFIG, BUCKET_POLICY_CONFIG,
|
||||
BUCKET_REPLICATION_CONFIG, BUCKET_SSECONFIG, BUCKET_TAGGING_CONFIG, BUCKET_VERSIONING_CONFIG, OBJECT_LOCK_CONFIG,
|
||||
},
|
||||
metadata_sys,
|
||||
metadata_sys::get_replication_config,
|
||||
@@ -818,6 +819,205 @@ async fn get_validated_store(bucket: &str) -> S3Result<Arc<rustfs_ecstore::store
|
||||
Ok(store)
|
||||
}
|
||||
|
||||
/// Quick check if CORS processing is needed (lightweight check for Origin header)
|
||||
/// This avoids unnecessary function calls for non-CORS requests
|
||||
#[inline]
|
||||
fn needs_cors_processing(headers: &HeaderMap) -> bool {
|
||||
headers.contains_key(cors::standard::ORIGIN)
|
||||
}
|
||||
|
||||
/// Apply CORS headers to response based on bucket CORS configuration and request origin
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Reads the Origin header from the request
|
||||
/// 2. Retrieves the bucket's CORS configuration
|
||||
/// 3. Matches the origin against CORS rules
|
||||
/// 4. Validates AllowedHeaders if request headers are present
|
||||
/// 5. Returns headers to add to the response if a match is found
|
||||
///
|
||||
/// Note: This function should only be called if `needs_cors_processing()` returns true
|
||||
/// to avoid unnecessary overhead for non-CORS requests.
|
||||
pub(crate) async fn apply_cors_headers(bucket: &str, method: &http::Method, headers: &HeaderMap) -> Option<HeaderMap> {
|
||||
use http::HeaderValue;
|
||||
|
||||
// Get Origin header from request
|
||||
let origin = headers.get(cors::standard::ORIGIN)?.to_str().ok()?;
|
||||
|
||||
// Get CORS configuration for the bucket
|
||||
let cors_config = match metadata_sys::get_cors_config(bucket).await {
|
||||
Ok((config, _)) => config,
|
||||
Err(_) => return None, // No CORS config, no headers to add
|
||||
};
|
||||
|
||||
// Early return if no CORS rules configured
|
||||
if cors_config.cors_rules.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if method is supported and get its string representation
|
||||
const SUPPORTED_METHODS: &[&str] = &["GET", "PUT", "POST", "DELETE", "HEAD", "OPTIONS"];
|
||||
let method_str = method.as_str();
|
||||
if !SUPPORTED_METHODS.contains(&method_str) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// For OPTIONS (preflight) requests, check Access-Control-Request-Method
|
||||
let is_preflight = method == http::Method::OPTIONS;
|
||||
let requested_method = if is_preflight {
|
||||
headers
|
||||
.get(cors::request::ACCESS_CONTROL_REQUEST_METHOD)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or(method_str)
|
||||
} else {
|
||||
method_str
|
||||
};
|
||||
|
||||
// Get requested headers from preflight request
|
||||
let requested_headers = if is_preflight {
|
||||
headers
|
||||
.get(cors::request::ACCESS_CONTROL_REQUEST_HEADERS)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|h| h.split(',').map(|s| s.trim().to_lowercase()).collect::<Vec<_>>())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Find matching CORS rule
|
||||
for rule in cors_config.cors_rules.iter() {
|
||||
// Check if origin matches
|
||||
let origin_matches = rule.allowed_origins.iter().any(|allowed_origin| {
|
||||
if allowed_origin == "*" {
|
||||
true
|
||||
} else {
|
||||
// Exact match or pattern match (support wildcards like https://*.example.com)
|
||||
allowed_origin == origin || matches_origin_pattern(allowed_origin, origin)
|
||||
}
|
||||
});
|
||||
|
||||
if !origin_matches {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if method is allowed
|
||||
let method_allowed = rule
|
||||
.allowed_methods
|
||||
.iter()
|
||||
.any(|allowed_method| allowed_method.as_str() == requested_method);
|
||||
|
||||
if !method_allowed {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate AllowedHeaders if present in the request
|
||||
if let Some(ref req_headers) = requested_headers {
|
||||
if let Some(ref allowed_headers) = rule.allowed_headers {
|
||||
// Check if all requested headers are allowed
|
||||
let all_headers_allowed = req_headers.iter().all(|req_header| {
|
||||
allowed_headers.iter().any(|allowed_header| {
|
||||
let allowed_lower = allowed_header.to_lowercase();
|
||||
// "*" allows all headers, or exact match
|
||||
allowed_lower == "*" || allowed_lower == *req_header
|
||||
})
|
||||
});
|
||||
|
||||
if !all_headers_allowed {
|
||||
// If not all headers are allowed, skip this rule
|
||||
continue;
|
||||
}
|
||||
} else if !req_headers.is_empty() {
|
||||
// If no AllowedHeaders specified but headers were requested, skip this rule
|
||||
// Unless the rule explicitly allows all headers
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Found matching rule, build response headers
|
||||
let mut response_headers = HeaderMap::new();
|
||||
|
||||
// Access-Control-Allow-Origin
|
||||
// If origin is "*", use "*", otherwise echo back the origin
|
||||
let has_wildcard_origin = rule.allowed_origins.iter().any(|o| o == "*");
|
||||
if has_wildcard_origin {
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
|
||||
} else if let Ok(origin_value) = HeaderValue::from_str(origin) {
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_ORIGIN, origin_value);
|
||||
}
|
||||
|
||||
// Vary: Origin (required for caching, except when using wildcard)
|
||||
if !has_wildcard_origin {
|
||||
response_headers.insert(cors::standard::VARY, HeaderValue::from_static("Origin"));
|
||||
}
|
||||
|
||||
// Access-Control-Allow-Methods (required for preflight)
|
||||
if is_preflight || !rule.allowed_methods.is_empty() {
|
||||
let methods_str = rule.allowed_methods.iter().map(|m| m.as_str()).collect::<Vec<_>>().join(", ");
|
||||
if let Ok(methods_value) = HeaderValue::from_str(&methods_str) {
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_METHODS, methods_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Access-Control-Allow-Headers (required for preflight if headers were requested)
|
||||
if is_preflight && let Some(ref allowed_headers) = rule.allowed_headers {
|
||||
let headers_str = allowed_headers.iter().map(|h| h.as_str()).collect::<Vec<_>>().join(", ");
|
||||
if let Ok(headers_value) = HeaderValue::from_str(&headers_str) {
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_ALLOW_HEADERS, headers_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Access-Control-Expose-Headers (for actual requests)
|
||||
if !is_preflight && let Some(ref expose_headers) = rule.expose_headers {
|
||||
let expose_headers_str = expose_headers.iter().map(|h| h.as_str()).collect::<Vec<_>>().join(", ");
|
||||
if let Ok(expose_value) = HeaderValue::from_str(&expose_headers_str) {
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_EXPOSE_HEADERS, expose_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Access-Control-Max-Age (for preflight requests)
|
||||
if is_preflight
|
||||
&& let Some(max_age) = rule.max_age_seconds
|
||||
&& let Ok(max_age_value) = HeaderValue::from_str(&max_age.to_string())
|
||||
{
|
||||
response_headers.insert(cors::response::ACCESS_CONTROL_MAX_AGE, max_age_value);
|
||||
}
|
||||
|
||||
return Some(response_headers);
|
||||
}
|
||||
|
||||
None // No matching rule found
|
||||
}
|
||||
/// Check if an origin matches a pattern (supports wildcards like https://*.example.com)
|
||||
fn matches_origin_pattern(pattern: &str, origin: &str) -> bool {
|
||||
// Simple wildcard matching: * matches any sequence
|
||||
if pattern.contains('*') {
|
||||
let pattern_parts: Vec<&str> = pattern.split('*').collect();
|
||||
if pattern_parts.len() == 2 {
|
||||
origin.starts_with(pattern_parts[0]) && origin.ends_with(pattern_parts[1])
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
pattern == origin
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap S3Response with CORS headers if needed
|
||||
/// This function performs a lightweight check first to avoid unnecessary CORS processing
|
||||
/// for non-CORS requests (requests without Origin header)
|
||||
async fn wrap_response_with_cors<T>(bucket: &str, method: &http::Method, headers: &HeaderMap, output: T) -> S3Response<T> {
|
||||
let mut response = S3Response::new(output);
|
||||
|
||||
// Quick check: only process CORS if Origin header is present
|
||||
if needs_cors_processing(headers)
|
||||
&& let Some(cors_headers) = apply_cors_headers(bucket, method, headers).await
|
||||
{
|
||||
for (key, value) in cors_headers.iter() {
|
||||
response.headers.insert(key, value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl S3 for FS {
|
||||
#[instrument(
|
||||
@@ -2598,7 +2798,8 @@ impl S3 for FS {
|
||||
cache_key, response_content_length, total_duration, optimal_buffer_size
|
||||
);
|
||||
|
||||
let result = Ok(S3Response::new(output));
|
||||
let response = wrap_response_with_cors(&bucket, &req.method, &req.headers, output).await;
|
||||
let result = Ok(response);
|
||||
let _ = helper.complete(&result);
|
||||
result
|
||||
}
|
||||
@@ -2843,7 +3044,14 @@ impl S3 for FS {
|
||||
let version_id = req.input.version_id.clone().unwrap_or_default();
|
||||
helper = helper.object(event_info).version_id(version_id);
|
||||
|
||||
let result = Ok(S3Response::new(output));
|
||||
// NOTE ON CORS:
|
||||
// Bucket-level CORS headers are intentionally applied only for object retrieval
|
||||
// operations (GET/HEAD) via `wrap_response_with_cors`. Other S3 operations that
|
||||
// interact with objects (PUT/POST/DELETE/LIST, etc.) rely on the system-level
|
||||
// CORS layer instead. In case both are applicable, this bucket-level CORS logic
|
||||
// takes precedence for these read operations.
|
||||
let response = wrap_response_with_cors(&bucket, &req.method, &req.headers, output).await;
|
||||
let result = Ok(response);
|
||||
let _ = helper.complete(&result);
|
||||
|
||||
result
|
||||
@@ -4713,6 +4921,82 @@ impl S3 for FS {
|
||||
Ok(S3Response::new(DeleteBucketTaggingOutput {}))
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self))]
|
||||
async fn get_bucket_cors(&self, req: S3Request<GetBucketCorsInput>) -> S3Result<S3Response<GetBucketCorsOutput>> {
|
||||
let bucket = req.input.bucket.clone();
|
||||
// check bucket exists.
|
||||
let _bucket = self
|
||||
.head_bucket(req.map_input(|input| HeadBucketInput {
|
||||
bucket: input.bucket,
|
||||
expected_bucket_owner: None,
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let cors_configuration = match metadata_sys::get_cors_config(&bucket).await {
|
||||
Ok((config, _)) => config,
|
||||
Err(err) => {
|
||||
if err == StorageError::ConfigNotFound {
|
||||
return Err(S3Error::with_message(
|
||||
S3ErrorCode::NoSuchCORSConfiguration,
|
||||
"The CORS configuration does not exist".to_string(),
|
||||
));
|
||||
}
|
||||
warn!("get_cors_config err {:?}", &err);
|
||||
return Err(ApiError::from(err).into());
|
||||
}
|
||||
};
|
||||
|
||||
Ok(S3Response::new(GetBucketCorsOutput {
|
||||
cors_rules: Some(cors_configuration.cors_rules),
|
||||
}))
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self))]
|
||||
async fn put_bucket_cors(&self, req: S3Request<PutBucketCorsInput>) -> S3Result<S3Response<PutBucketCorsOutput>> {
|
||||
let PutBucketCorsInput {
|
||||
bucket,
|
||||
cors_configuration,
|
||||
..
|
||||
} = req.input;
|
||||
|
||||
let Some(store) = new_object_layer_fn() else {
|
||||
return Err(S3Error::with_message(S3ErrorCode::InternalError, "Not init".to_string()));
|
||||
};
|
||||
|
||||
store
|
||||
.get_bucket_info(&bucket, &BucketOptions::default())
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
|
||||
let data = try_!(serialize(&cors_configuration));
|
||||
|
||||
metadata_sys::update(&bucket, BUCKET_CORS_CONFIG, data)
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
|
||||
Ok(S3Response::new(PutBucketCorsOutput::default()))
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self))]
|
||||
async fn delete_bucket_cors(&self, req: S3Request<DeleteBucketCorsInput>) -> S3Result<S3Response<DeleteBucketCorsOutput>> {
|
||||
let DeleteBucketCorsInput { bucket, .. } = req.input;
|
||||
|
||||
let Some(store) = new_object_layer_fn() else {
|
||||
return Err(S3Error::with_message(S3ErrorCode::InternalError, "Not init".to_string()));
|
||||
};
|
||||
|
||||
store
|
||||
.get_bucket_info(&bucket, &BucketOptions::default())
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
|
||||
metadata_sys::delete(&bucket, BUCKET_CORS_CONFIG)
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
|
||||
Ok(S3Response::new(DeleteBucketCorsOutput {}))
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self, req))]
|
||||
async fn put_object_tagging(&self, req: S3Request<PutObjectTaggingInput>) -> S3Result<S3Response<PutObjectTaggingOutput>> {
|
||||
let start_time = std::time::Instant::now();
|
||||
@@ -6557,4 +6841,201 @@ mod tests {
|
||||
assert!(filtered_version_marker.is_some());
|
||||
assert_eq!(filtered_version_marker.unwrap(), "null");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_origin_pattern_exact_match() {
|
||||
// Test exact match
|
||||
assert!(matches_origin_pattern("https://example.com", "https://example.com"));
|
||||
assert!(matches_origin_pattern("http://localhost:3000", "http://localhost:3000"));
|
||||
assert!(!matches_origin_pattern("https://example.com", "https://other.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_origin_pattern_wildcard() {
|
||||
// Test wildcard pattern matching (S3 CORS supports * as subdomain wildcard)
|
||||
assert!(matches_origin_pattern("https://*.example.com", "https://app.example.com"));
|
||||
assert!(matches_origin_pattern("https://*.example.com", "https://api.example.com"));
|
||||
assert!(matches_origin_pattern("https://*.example.com", "https://subdomain.example.com"));
|
||||
|
||||
// Test wildcard at start (matches any domain)
|
||||
assert!(matches_origin_pattern("https://*", "https://example.com"));
|
||||
assert!(matches_origin_pattern("https://*", "https://any-domain.com"));
|
||||
|
||||
// Test wildcard at end (matches any protocol)
|
||||
assert!(matches_origin_pattern("*://example.com", "https://example.com"));
|
||||
assert!(matches_origin_pattern("*://example.com", "http://example.com"));
|
||||
|
||||
// Test invalid wildcard patterns (should not match)
|
||||
assert!(!matches_origin_pattern("https://*.*.com", "https://app.example.com")); // Multiple wildcards (invalid pattern)
|
||||
// Note: "https://*example.com" actually matches "https://app.example.com" with our current implementation
|
||||
// because it splits on * and checks starts_with/ends_with. This is a limitation but acceptable
|
||||
// for S3 CORS which typically uses patterns like "https://*.example.com"
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_origin_pattern_no_wildcard() {
|
||||
// Test patterns without wildcards
|
||||
assert!(matches_origin_pattern("https://example.com", "https://example.com"));
|
||||
assert!(!matches_origin_pattern("https://example.com", "https://example.org"));
|
||||
assert!(!matches_origin_pattern("http://example.com", "https://example.com")); // Different protocol
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_origin_pattern_edge_cases() {
|
||||
// Test edge cases
|
||||
assert!(!matches_origin_pattern("", "https://example.com")); // Empty pattern
|
||||
assert!(!matches_origin_pattern("https://example.com", "")); // Empty origin
|
||||
assert!(matches_origin_pattern("", "")); // Both empty
|
||||
assert!(!matches_origin_pattern("https://example.com", "http://example.com")); // Protocol mismatch
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cors_headers_validation() {
|
||||
use http::HeaderMap;
|
||||
|
||||
// Test case 1: Validate header name case-insensitivity
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("access-control-request-headers", "Content-Type,X-Custom-Header".parse().unwrap());
|
||||
|
||||
let req_headers_str = headers
|
||||
.get("access-control-request-headers")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap();
|
||||
let req_headers: Vec<String> = req_headers_str.split(',').map(|s| s.trim().to_lowercase()).collect();
|
||||
|
||||
// Headers should be lowercased for comparison
|
||||
assert_eq!(req_headers, vec!["content-type", "x-custom-header"]);
|
||||
|
||||
// Test case 2: Wildcard matching
|
||||
let allowed_headers = ["*".to_string()];
|
||||
let all_allowed = req_headers.iter().all(|req_header| {
|
||||
allowed_headers
|
||||
.iter()
|
||||
.any(|allowed| allowed.to_lowercase() == "*" || allowed.to_lowercase() == *req_header)
|
||||
});
|
||||
assert!(all_allowed, "Wildcard should allow all headers");
|
||||
|
||||
// Test case 3: Specific header matching
|
||||
let allowed_headers = ["content-type".to_string(), "x-custom-header".to_string()];
|
||||
let all_allowed = req_headers
|
||||
.iter()
|
||||
.all(|req_header| allowed_headers.iter().any(|allowed| allowed.to_lowercase() == *req_header));
|
||||
assert!(all_allowed, "All requested headers should be allowed");
|
||||
|
||||
// Test case 4: Disallowed header
|
||||
let req_headers = ["content-type".to_string(), "x-forbidden-header".to_string()];
|
||||
let allowed_headers = ["content-type".to_string()];
|
||||
let all_allowed = req_headers
|
||||
.iter()
|
||||
.all(|req_header| allowed_headers.iter().any(|allowed| allowed.to_lowercase() == *req_header));
|
||||
assert!(!all_allowed, "Forbidden header should not be allowed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cors_response_headers_structure() {
|
||||
use http::{HeaderMap, HeaderValue};
|
||||
|
||||
let mut cors_headers = HeaderMap::new();
|
||||
|
||||
// Simulate building CORS response headers
|
||||
let origin = "https://example.com";
|
||||
let methods = ["GET", "PUT", "POST"];
|
||||
let allowed_headers = ["Content-Type", "Authorization"];
|
||||
let expose_headers = ["ETag", "x-amz-version-id"];
|
||||
let max_age = 3600;
|
||||
|
||||
// Add headers
|
||||
cors_headers.insert("access-control-allow-origin", HeaderValue::from_str(origin).unwrap());
|
||||
cors_headers.insert("vary", HeaderValue::from_static("Origin"));
|
||||
|
||||
let methods_str = methods.join(", ");
|
||||
cors_headers.insert("access-control-allow-methods", HeaderValue::from_str(&methods_str).unwrap());
|
||||
|
||||
let headers_str = allowed_headers.join(", ");
|
||||
cors_headers.insert("access-control-allow-headers", HeaderValue::from_str(&headers_str).unwrap());
|
||||
|
||||
let expose_str = expose_headers.join(", ");
|
||||
cors_headers.insert("access-control-expose-headers", HeaderValue::from_str(&expose_str).unwrap());
|
||||
|
||||
cors_headers.insert("access-control-max-age", HeaderValue::from_str(&max_age.to_string()).unwrap());
|
||||
|
||||
// Verify all headers are present
|
||||
assert_eq!(cors_headers.get("access-control-allow-origin").unwrap(), origin);
|
||||
assert_eq!(cors_headers.get("vary").unwrap(), "Origin");
|
||||
assert_eq!(cors_headers.get("access-control-allow-methods").unwrap(), "GET, PUT, POST");
|
||||
assert_eq!(cors_headers.get("access-control-allow-headers").unwrap(), "Content-Type, Authorization");
|
||||
assert_eq!(cors_headers.get("access-control-expose-headers").unwrap(), "ETag, x-amz-version-id");
|
||||
assert_eq!(cors_headers.get("access-control-max-age").unwrap(), "3600");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cors_preflight_vs_actual_request() {
|
||||
use http::Method;
|
||||
|
||||
// Test that we can distinguish preflight from actual requests
|
||||
let preflight_method = Method::OPTIONS;
|
||||
let actual_method = Method::PUT;
|
||||
|
||||
assert_eq!(preflight_method, Method::OPTIONS);
|
||||
assert_ne!(actual_method, Method::OPTIONS);
|
||||
|
||||
// Preflight should check Access-Control-Request-Method
|
||||
// Actual request should use the actual method
|
||||
let is_preflight_1 = preflight_method == Method::OPTIONS;
|
||||
let is_preflight_2 = actual_method == Method::OPTIONS;
|
||||
|
||||
assert!(is_preflight_1);
|
||||
assert!(!is_preflight_2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_cors_headers_no_origin() {
|
||||
// Test when no Origin header is present
|
||||
let headers = HeaderMap::new();
|
||||
let method = http::Method::GET;
|
||||
|
||||
// Should return None when no origin header
|
||||
let result = apply_cors_headers("test-bucket", &method, &headers).await;
|
||||
assert!(result.is_none(), "Should return None when no Origin header");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_cors_headers_no_cors_config() {
|
||||
// Test when bucket has no CORS configuration
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("origin", "https://example.com".parse().unwrap());
|
||||
let method = http::Method::GET;
|
||||
|
||||
// Should return None when no CORS config exists
|
||||
// Note: This test may fail if test-bucket actually has CORS config
|
||||
// In a real scenario, we'd use a mock or ensure the bucket doesn't exist
|
||||
let _result = apply_cors_headers("non-existent-bucket-for-testing", &method, &headers).await;
|
||||
// Result depends on whether bucket exists and has CORS config
|
||||
// This is expected behavior - we just verify it doesn't panic
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_cors_headers_unsupported_method() {
|
||||
// Test with unsupported HTTP method
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("origin", "https://example.com".parse().unwrap());
|
||||
let method = http::Method::PATCH; // Unsupported method
|
||||
|
||||
let result = apply_cors_headers("test-bucket", &method, &headers).await;
|
||||
assert!(result.is_none(), "Should return None for unsupported methods");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_origin_pattern_complex_wildcards() {
|
||||
// Test more complex wildcard scenarios
|
||||
assert!(matches_origin_pattern("https://*.example.com", "https://sub.example.com"));
|
||||
// Note: "https://*.example.com" matches "https://api.sub.example.com" with our implementation
|
||||
// because it only checks starts_with and ends_with. Real S3 might be more strict.
|
||||
|
||||
// Test wildcard in middle position
|
||||
// Our implementation allows this, but it's not standard S3 CORS pattern
|
||||
// The pattern "https://example.*.com" splits to ["https://example.", ".com"]
|
||||
// and "https://example.sub.com" matches because it starts with "https://example." and ends with ".com"
|
||||
// This is acceptable for our use case as S3 CORS typically uses "https://*.example.com" format
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user