Compare commits

...

15 Commits

Author SHA1 Message Date
weisd
56fd8132e9 fix:#303 returns empty when querying an empty or not dir (#304) 2025-07-28 16:17:40 +08:00
guojidan
35daa74430 Merge pull request #302 from guojidan/lock
Lock: add transactional
2025-07-28 12:00:44 +08:00
junxiang Mu
dc156fb4cd Fix: clippy
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-28 11:38:42 +08:00
junxiang Mu
de905a878c Cargo: use workspace dependence
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-28 11:02:40 +08:00
junxiang Mu
f3252f989b Test: Add e2e test case for lock transactional
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-28 11:00:10 +08:00
junxiang Mu
01a2afca9a lock: Add transactional
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-28 10:59:43 +08:00
guojidan
a4fe68ad21 Merge pull request #301 from guojidan/improve-sql
s3Select: add unit test case
2025-07-28 09:56:10 +08:00
junxiang Mu
c03f86b23c s3Select: add unit test case
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-28 09:19:47 +08:00
guojidan
5667f324ae Merge pull request #297 from guojidan/improve-sql
Test: Add e2e_test case for sql && add script for e2e_test
2025-07-25 17:16:41 +08:00
junxiang Mu
bcd806796f Test: add test script for e2e
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-25 16:52:06 +08:00
junxiang Mu
612404c47f Test: add e2e_test for s3select
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-25 15:07:44 +08:00
guojidan
85388262b3 Merge pull request #294 from guojidan/improve-sql
Refactor: DatabaseManagerSystem as global
2025-07-25 08:33:54 +08:00
junxiang Mu
25a4503285 fix: fmt
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-25 08:18:14 +08:00
安正超
526c4d5a61 refactor: optimize the build workflow, unify latest-file handling, and simplify artifact uploads (#293) 2025-07-25 01:10:04 +08:00
junxiang Mu
addc964d56 Refactor: DatabaseManagerSystem as global
Signed-off-by: junxiang Mu <1948535941@qq.com>
2025-07-24 17:12:51 +08:00
19 changed files with 2195 additions and 117 deletions

View File

@@ -383,20 +383,66 @@ jobs:
exit 1
fi
# Create latest version files right after the main package
LATEST_FILES=""
if [[ "$BUILD_TYPE" == "release" ]] || [[ "$BUILD_TYPE" == "prerelease" ]]; then
# Create latest version filename
# Convert from rustfs-linux-x86_64-musl-v1.0.0 to rustfs-linux-x86_64-musl-latest
LATEST_FILE="${PACKAGE_NAME%-v*}-latest.zip"
echo "🔄 Creating latest version: ${PACKAGE_NAME}.zip -> $LATEST_FILE"
cp "${PACKAGE_NAME}.zip" "$LATEST_FILE"
if [[ -f "$LATEST_FILE" ]]; then
echo "✅ Latest version created: $LATEST_FILE"
LATEST_FILES="$LATEST_FILE"
fi
elif [[ "$BUILD_TYPE" == "development" ]]; then
# Development builds (only main branch triggers development builds)
# Create main-latest version filename
# Convert from rustfs-linux-x86_64-dev-abc123 to rustfs-linux-x86_64-main-latest
MAIN_LATEST_FILE="${PACKAGE_NAME%-dev-*}-main-latest.zip"
echo "🔄 Creating main-latest version: ${PACKAGE_NAME}.zip -> $MAIN_LATEST_FILE"
cp "${PACKAGE_NAME}.zip" "$MAIN_LATEST_FILE"
if [[ -f "$MAIN_LATEST_FILE" ]]; then
echo "✅ Main-latest version created: $MAIN_LATEST_FILE"
LATEST_FILES="$MAIN_LATEST_FILE"
# Also create a generic main-latest for Docker builds (Linux only)
if [[ "${{ matrix.platform }}" == "linux" ]]; then
DOCKER_MAIN_LATEST_FILE="rustfs-linux-${ARCH_WITH_VARIANT}-main-latest.zip"
echo "🔄 Creating Docker main-latest version: ${PACKAGE_NAME}.zip -> $DOCKER_MAIN_LATEST_FILE"
cp "${PACKAGE_NAME}.zip" "$DOCKER_MAIN_LATEST_FILE"
if [[ -f "$DOCKER_MAIN_LATEST_FILE" ]]; then
echo "✅ Docker main-latest version created: $DOCKER_MAIN_LATEST_FILE"
LATEST_FILES="$LATEST_FILES $DOCKER_MAIN_LATEST_FILE"
fi
fi
fi
fi
echo "package_name=${PACKAGE_NAME}" >> $GITHUB_OUTPUT
echo "package_file=${PACKAGE_NAME}.zip" >> $GITHUB_OUTPUT
echo "latest_files=${LATEST_FILES}" >> $GITHUB_OUTPUT
echo "build_type=${BUILD_TYPE}" >> $GITHUB_OUTPUT
echo "version=${VERSION}" >> $GITHUB_OUTPUT
echo "📦 Package created: ${PACKAGE_NAME}.zip"
if [[ -n "$LATEST_FILES" ]]; then
echo "📦 Latest files created: $LATEST_FILES"
fi
echo "🔧 Build type: ${BUILD_TYPE}"
echo "📊 Version: ${VERSION}"
- name: Upload artifacts
- name: Upload to GitHub artifacts
uses: actions/upload-artifact@v4
with:
name: ${{ steps.package.outputs.package_name }}
path: ${{ steps.package.outputs.package_file }}
path: "rustfs-*.zip"
retention-days: ${{ startsWith(github.ref, 'refs/tags/') && 30 || 7 }}
- name: Upload to Aliyun OSS
@@ -466,73 +512,15 @@ jobs:
echo "📤 Uploading release build to OSS release directory"
fi
# Upload the package file to OSS
echo "Uploading ${{ steps.package.outputs.package_file }} to $OSS_PATH..."
$OSSUTIL_BIN cp "${{ steps.package.outputs.package_file }}" "$OSS_PATH" --force
# For release and prerelease builds, also create a latest version
if [[ "$BUILD_TYPE" == "release" ]] || [[ "$BUILD_TYPE" == "prerelease" ]]; then
# Extract platform and arch from package name
PACKAGE_NAME="${{ steps.package.outputs.package_name }}"
# Create latest version filename
# Convert from rustfs-linux-x86_64-v1.0.0 to rustfs-linux-x86_64-latest
LATEST_FILE="${PACKAGE_NAME%-v*}-latest.zip"
# Copy the original file to latest version
cp "${{ steps.package.outputs.package_file }}" "$LATEST_FILE"
# Upload the latest version
echo "Uploading latest version: $LATEST_FILE to $OSS_PATH..."
$OSSUTIL_BIN cp "$LATEST_FILE" "$OSS_PATH" --force
echo "✅ Latest version uploaded: $LATEST_FILE"
fi
# For development builds, create dev-latest version
if [[ "$BUILD_TYPE" == "development" ]]; then
# Extract platform and arch from package name
PACKAGE_NAME="${{ steps.package.outputs.package_name }}"
# Create dev-latest version filename
# Convert from rustfs-linux-x86_64-dev-abc123 to rustfs-linux-x86_64-dev-latest
DEV_LATEST_FILE="${PACKAGE_NAME%-*}-latest.zip"
# Copy the original file to dev-latest version
cp "${{ steps.package.outputs.package_file }}" "$DEV_LATEST_FILE"
# Upload the dev-latest version
echo "Uploading dev-latest version: $DEV_LATEST_FILE to $OSS_PATH..."
$OSSUTIL_BIN cp "$DEV_LATEST_FILE" "$OSS_PATH" --force
echo "✅ Dev-latest version uploaded: $DEV_LATEST_FILE"
# For main branch builds, also create a main-latest version
if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then
# Create main-latest version filename
# Convert from rustfs-linux-x86_64-dev-abc123 to rustfs-linux-x86_64-main-latest
MAIN_LATEST_FILE="${PACKAGE_NAME%-dev-*}-main-latest.zip"
# Copy the original file to main-latest version
cp "${{ steps.package.outputs.package_file }}" "$MAIN_LATEST_FILE"
# Upload the main-latest version
echo "Uploading main-latest version: $MAIN_LATEST_FILE to $OSS_PATH..."
$OSSUTIL_BIN cp "$MAIN_LATEST_FILE" "$OSS_PATH" --force
echo "✅ Main-latest version uploaded: $MAIN_LATEST_FILE"
# Also create a generic main-latest for Docker builds
if [[ "${{ matrix.platform }}" == "linux" ]]; then
# Use the same ARCH_WITH_VARIANT logic for Docker files
DOCKER_MAIN_LATEST_FILE="rustfs-linux-${ARCH_WITH_VARIANT}-main-latest.zip"
cp "${{ steps.package.outputs.package_file }}" "$DOCKER_MAIN_LATEST_FILE"
$OSSUTIL_BIN cp "$DOCKER_MAIN_LATEST_FILE" "$OSS_PATH" --force
echo "✅ Docker main-latest version uploaded: $DOCKER_MAIN_LATEST_FILE"
fi
# Upload all rustfs zip files to OSS using glob pattern
echo "📤 Uploading all rustfs-*.zip files to $OSS_PATH..."
for zip_file in rustfs-*.zip; do
if [[ -f "$zip_file" ]]; then
echo "Uploading: $zip_file to $OSS_PATH..."
$OSSUTIL_BIN cp "$zip_file" "$OSS_PATH" --force
echo "✅ Uploaded: $zip_file"
fi
fi
done
echo "✅ Upload completed successfully"
@@ -703,7 +691,7 @@ jobs:
mkdir -p ./release-assets
# Copy and verify artifacts
# Copy and verify artifacts (including latest files created during build)
ASSETS_COUNT=0
for file in ./artifacts/*.zip; do
if [[ -f "$file" ]]; then
@@ -719,7 +707,7 @@ jobs:
cd ./release-assets
# Generate checksums
# Generate checksums for all files (including latest versions)
if ls *.zip >/dev/null 2>&1; then
sha256sum *.zip > SHA256SUMS
sha512sum *.zip > SHA512SUMS
@@ -734,7 +722,7 @@ jobs:
echo "📦 Prepared assets:"
ls -la
echo "🔢 Asset count: $ASSETS_COUNT"
echo "🔢 Total asset count: $ASSETS_COUNT"
- name: Upload to GitHub Release
env:
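
Editor's note (not part of the workflow): the latest-file names above come from shell suffix stripping — ${PACKAGE_NAME%-v*} drops the shortest suffix starting at the last "-v", and ${PACKAGE_NAME%-dev-*} does the same from the last "-dev-". A minimal Rust rendering of that naming rule, with a hypothetical helper name, is sketched below.

fn strip_from_last(name: &str, marker: &str) -> &str {
    // Mirrors "${name%<marker>*}": drop everything from the last occurrence of `marker`.
    name.rfind(marker).map_or(name, |i| &name[..i])
}

fn main() {
    // release/prerelease: rustfs-linux-x86_64-musl-v1.0.0 -> rustfs-linux-x86_64-musl-latest.zip
    let latest = format!("{}-latest.zip", strip_from_last("rustfs-linux-x86_64-musl-v1.0.0", "-v"));
    assert_eq!(latest, "rustfs-linux-x86_64-musl-latest.zip");

    // development (main branch): rustfs-linux-x86_64-dev-abc123 -> rustfs-linux-x86_64-main-latest.zip
    let main_latest = format!("{}-main-latest.zip", strip_from_last("rustfs-linux-x86_64-dev-abc123", "-dev-"));
    assert_eq!(main_latest, "rustfs-linux-x86_64-main-latest.zip");
}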

Cargo.lock generated
View File

@@ -672,6 +672,36 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "aws-config"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0baa720ebadea158c5bda642ac444a2af0cdf7bb66b46d1e4533de5d1f449d0"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-sdk-sso",
"aws-sdk-ssooidc",
"aws-sdk-sts",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"hex",
"http 1.3.1",
"ring",
"time",
"tokio",
"tracing",
"url",
"zeroize",
]
[[package]]
name = "aws-credential-types"
version = "1.2.4"
@@ -734,9 +764,9 @@ dependencies = [
[[package]]
name = "aws-sdk-s3"
version = "1.98.0"
version = "1.99.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "029e89cae7e628553643aecb3a3f054a0a0912ff0fd1f5d6a0b4fda421dce64b"
checksum = "b2d64d68c93000d5792b2a25fbeaafb90985fa80a1c8adfe93f24fb271296f5f"
dependencies = [
"aws-credential-types",
"aws-runtime",
@@ -766,6 +796,73 @@ dependencies = [
"url",
]
[[package]]
name = "aws-sdk-sso"
version = "1.77.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18f2f37fea82468fe3f5a059542c05392ef680c4f7f00e0db02df8b6e5c7d0c6"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"http 0.2.12",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-ssooidc"
version = "1.78.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecb4f6eada20e0193450cd48b12ed05e1e66baac86f39160191651b932f2b7d9"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"http 0.2.12",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-sts"
version = "1.79.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "317377afba3498fca4948c5d32b399ef9a5ad35561a1e8a6f2ac7273dabf802d"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-query",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-smithy-xml",
"aws-types",
"fastrand",
"http 0.2.12",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sigv4"
version = "1.3.3"
@@ -904,6 +1001,16 @@ dependencies = [
"aws-smithy-runtime-api",
]
[[package]]
name = "aws-smithy-query"
version = "0.60.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb"
dependencies = [
"aws-smithy-types",
"urlencoding",
]
[[package]]
name = "aws-smithy-runtime"
version = "1.8.5"
@@ -982,9 +1089,9 @@ dependencies = [
[[package]]
name = "aws-types"
version = "1.3.7"
version = "1.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a"
checksum = "b069d19bf01e46298eaedd7c6f283fe565a59263e53eebec945f3e6398f42390"
dependencies = [
"aws-credential-types",
"aws-smithy-async",
@@ -3466,6 +3573,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
name = "e2e_test"
version = "0.0.5"
dependencies = [
"async-trait",
"aws-config",
"aws-sdk-s3",
"bytes",
"flatbuffers 25.2.10",
"futures",

View File

@@ -39,3 +39,6 @@ rustfs-madmin.workspace = true
rustfs-filemeta.workspace = true
bytes.workspace = true
serial_test = "3.2.0"
aws-sdk-s3 = "1.99.0"
aws-config = "1.8.3"
async-trait = { workspace = true }

View File

@@ -13,12 +13,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use rustfs_ecstore::{disk::endpoint::Endpoint, lock_utils::create_unique_clients};
use rustfs_lock::client::{LockClient, local::LocalClient};
use rustfs_lock::types::{LockInfo, LockResponse, LockStats};
use rustfs_lock::{LockId, LockMetadata, LockPriority, LockType};
use rustfs_lock::{LockRequest, NamespaceLock, NamespaceLockManager};
use rustfs_protos::{node_service_time_out_client, proto_gen::node_service::GenerallyLockRequest};
use serial_test::serial;
use std::{error::Error, time::Duration};
use std::{error::Error, sync::Arc, time::Duration};
use tokio::time::sleep;
use tonic::Request;
use url::Url;
@@ -72,6 +75,216 @@ async fn test_lock_unlock_rpc() -> Result<(), Box<dyn Error>> {
Ok(())
}
/// Mock client that simulates remote node failures
#[derive(Debug)]
struct FailingMockClient {
local_client: Arc<dyn LockClient>,
should_fail_acquire: bool,
should_fail_release: bool,
}
impl FailingMockClient {
fn new(should_fail_acquire: bool, should_fail_release: bool) -> Self {
Self {
local_client: Arc::new(LocalClient::new()),
should_fail_acquire,
should_fail_release,
}
}
}
#[async_trait]
impl LockClient for FailingMockClient {
async fn acquire_exclusive(&self, request: &LockRequest) -> rustfs_lock::error::Result<LockResponse> {
if self.should_fail_acquire {
// Simulate network timeout or remote node failure
return Ok(LockResponse::failure("Simulated remote node failure", Duration::from_millis(100)));
}
self.local_client.acquire_exclusive(request).await
}
async fn acquire_shared(&self, request: &LockRequest) -> rustfs_lock::error::Result<LockResponse> {
if self.should_fail_acquire {
return Ok(LockResponse::failure("Simulated remote node failure", Duration::from_millis(100)));
}
self.local_client.acquire_shared(request).await
}
async fn release(&self, lock_id: &LockId) -> rustfs_lock::error::Result<bool> {
if self.should_fail_release {
return Err(rustfs_lock::error::LockError::internal("Simulated release failure"));
}
self.local_client.release(lock_id).await
}
async fn refresh(&self, lock_id: &LockId) -> rustfs_lock::error::Result<bool> {
self.local_client.refresh(lock_id).await
}
async fn force_release(&self, lock_id: &LockId) -> rustfs_lock::error::Result<bool> {
self.local_client.force_release(lock_id).await
}
async fn check_status(&self, lock_id: &LockId) -> rustfs_lock::error::Result<Option<LockInfo>> {
self.local_client.check_status(lock_id).await
}
async fn get_stats(&self) -> rustfs_lock::error::Result<LockStats> {
self.local_client.get_stats().await
}
async fn close(&self) -> rustfs_lock::error::Result<()> {
self.local_client.close().await
}
async fn is_online(&self) -> bool {
if self.should_fail_acquire {
return false; // Simulate offline node
}
true // Simulate online node
}
async fn is_local(&self) -> bool {
false // Simulate remote client
}
}
#[tokio::test]
#[serial]
async fn test_transactional_lock_with_remote_failure() -> Result<(), Box<dyn Error>> {
println!("🧪 Testing transactional lock with simulated remote node failure");
// Create a two-node cluster: one local (success) + one remote (failure)
let local_client: Arc<dyn LockClient> = Arc::new(LocalClient::new());
let failing_remote_client: Arc<dyn LockClient> = Arc::new(FailingMockClient::new(true, false));
let clients = vec![local_client, failing_remote_client];
let ns_lock = NamespaceLock::with_clients("test_transactional".to_string(), clients);
let resource = "critical_resource".to_string();
// Test single lock operation with 2PC
println!("📝 Testing single lock with remote failure...");
let request = LockRequest::new(&resource, LockType::Exclusive, "test_owner").with_ttl(Duration::from_secs(30));
let response = ns_lock.acquire_lock(&request).await?;
// Should fail because quorum (2/2) is not met due to remote failure
assert!(!response.success, "Lock should fail due to remote node failure");
println!("✅ Single lock correctly failed due to remote node failure");
// Verify no locks are left behind on the local node
let local_client_direct = LocalClient::new();
let lock_id = LockId::new_deterministic(&ns_lock.get_resource_key(&resource));
let lock_status = local_client_direct.check_status(&lock_id).await?;
assert!(lock_status.is_none(), "No lock should remain on local node after rollback");
println!("✅ Verified rollback: no locks left on local node");
Ok(())
}
#[tokio::test]
#[serial]
async fn test_transactional_batch_lock_with_mixed_failures() -> Result<(), Box<dyn Error>> {
println!("🧪 Testing transactional batch lock with mixed node failures");
// Create a cluster with different failure patterns
let local_client: Arc<dyn LockClient> = Arc::new(LocalClient::new());
let failing_remote_client: Arc<dyn LockClient> = Arc::new(FailingMockClient::new(true, false));
let clients = vec![local_client, failing_remote_client];
let ns_lock = NamespaceLock::with_clients("test_batch_transactional".to_string(), clients);
let resources = vec!["resource_1".to_string(), "resource_2".to_string(), "resource_3".to_string()];
println!("📝 Testing batch lock with remote failure...");
let result = ns_lock
.lock_batch(&resources, "batch_owner", Duration::from_millis(100), Duration::from_secs(30))
.await?;
// Should fail because remote node cannot acquire locks
assert!(!result, "Batch lock should fail due to remote node failure");
println!("✅ Batch lock correctly failed due to remote node failure");
// Verify no locks are left behind on any resource
let local_client_direct = LocalClient::new();
for resource in &resources {
let lock_id = LockId::new_deterministic(&ns_lock.get_resource_key(resource));
let lock_status = local_client_direct.check_status(&lock_id).await?;
assert!(lock_status.is_none(), "No lock should remain for resource: {resource}");
}
println!("✅ Verified rollback: no locks left on any resource");
Ok(())
}
#[tokio::test]
#[serial]
async fn test_transactional_lock_with_quorum_success() -> Result<(), Box<dyn Error>> {
println!("🧪 Testing transactional lock with quorum success");
// Create a three-node cluster where 2 succeed and 1 fails (quorum = 2 automatically)
let local_client1: Arc<dyn LockClient> = Arc::new(LocalClient::new());
let local_client2: Arc<dyn LockClient> = Arc::new(LocalClient::new());
let failing_remote_client: Arc<dyn LockClient> = Arc::new(FailingMockClient::new(true, false));
let clients = vec![local_client1, local_client2, failing_remote_client];
let ns_lock = NamespaceLock::with_clients("test_quorum".to_string(), clients);
let resource = "quorum_resource".to_string();
println!("📝 Testing lock with automatic quorum=2, 2 success + 1 failure...");
let request = LockRequest::new(&resource, LockType::Exclusive, "quorum_owner").with_ttl(Duration::from_secs(30));
let response = ns_lock.acquire_lock(&request).await?;
// Should fail because we require all nodes to succeed for consistency
// (even though quorum is met, the implementation requires all nodes for consistency)
assert!(!response.success, "Lock should fail due to consistency requirement");
println!("✅ Lock correctly failed due to consistency requirement (partial success rolled back)");
Ok(())
}
#[tokio::test]
#[serial]
async fn test_transactional_lock_rollback_on_release_failure() -> Result<(), Box<dyn Error>> {
println!("🧪 Testing rollback behavior when release fails");
// Create clients where acquire succeeds but release fails
let local_client: Arc<dyn LockClient> = Arc::new(LocalClient::new());
let failing_release_client: Arc<dyn LockClient> = Arc::new(FailingMockClient::new(false, true));
let clients = vec![local_client, failing_release_client];
let ns_lock = NamespaceLock::with_clients("test_release_failure".to_string(), clients);
let resource = "release_test_resource".to_string();
println!("📝 Testing lock acquisition with release failure handling...");
let request = LockRequest::new(&resource, LockType::Exclusive, "test_owner").with_ttl(Duration::from_secs(30));
// This should fail because both LocalClient instances share the same global lock map
// The first client (LocalClient) will acquire the lock, but the second client
// (FailingMockClient's internal LocalClient) will fail to acquire the same resource
let response = ns_lock.acquire_lock(&request).await?;
// The operation should fail due to lock contention between the two LocalClient instances
assert!(
!response.success,
"Lock should fail due to lock contention between LocalClient instances sharing global lock map"
);
println!("✅ Lock correctly failed due to lock contention (both clients use same global lock map)");
// Verify no locks are left behind after rollback
let local_client_direct = LocalClient::new();
let lock_id = LockId::new_deterministic(&ns_lock.get_resource_key(&resource));
let lock_status = local_client_direct.check_status(&lock_id).await?;
assert!(lock_status.is_none(), "No lock should remain after rollback");
println!("✅ Verified rollback: no locks left after failed acquisition");
Ok(())
}
#[tokio::test]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]

View File

@@ -14,3 +14,4 @@
mod lock;
mod node_interact_test;
mod sql;

View File

@@ -0,0 +1,402 @@
#![cfg(test)]
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::Client;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::types::{
CsvInput, CsvOutput, ExpressionType, FileHeaderInfo, InputSerialization, JsonInput, JsonOutput, JsonType, OutputSerialization,
};
use bytes::Bytes;
use serial_test::serial;
use std::error::Error;
const ENDPOINT: &str = "http://localhost:9000";
const ACCESS_KEY: &str = "rustfsadmin";
const SECRET_KEY: &str = "rustfsadmin";
const BUCKET: &str = "test-sql-bucket";
const CSV_OBJECT: &str = "test-data.csv";
const JSON_OBJECT: &str = "test-data.json";
async fn create_aws_s3_client() -> Result<Client, Box<dyn Error>> {
let region_provider = RegionProviderChain::default_provider().or_else(Region::new("us-east-1"));
let shared_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
.region(region_provider)
.credentials_provider(Credentials::new(ACCESS_KEY, SECRET_KEY, None, None, "static"))
.endpoint_url(ENDPOINT)
.load()
.await;
let client = Client::from_conf(
aws_sdk_s3::Config::from(&shared_config)
.to_builder()
.force_path_style(true) // Important for S3-compatible services
.build(),
);
Ok(client)
}
async fn setup_test_bucket(client: &Client) -> Result<(), Box<dyn Error>> {
match client.create_bucket().bucket(BUCKET).send().await {
Ok(_) => {}
Err(e) => {
let error_str = e.to_string();
if !error_str.contains("BucketAlreadyOwnedByYou") && !error_str.contains("BucketAlreadyExists") {
return Err(e.into());
}
}
}
Ok(())
}
async fn upload_test_csv(client: &Client) -> Result<(), Box<dyn Error>> {
let csv_data = "name,age,city\nAlice,30,New York\nBob,25,Los Angeles\nCharlie,35,Chicago\nDiana,28,Boston";
client
.put_object()
.bucket(BUCKET)
.key(CSV_OBJECT)
.body(Bytes::from(csv_data.as_bytes()).into())
.send()
.await?;
Ok(())
}
async fn upload_test_json(client: &Client) -> Result<(), Box<dyn Error>> {
let json_data = r#"{"name":"Alice","age":30,"city":"New York"}
{"name":"Bob","age":25,"city":"Los Angeles"}
{"name":"Charlie","age":35,"city":"Chicago"}
{"name":"Diana","age":28,"city":"Boston"}"#;
client
.put_object()
.bucket(BUCKET)
.key(JSON_OBJECT)
.body(Bytes::from(json_data.as_bytes()).into())
.send()
.await?;
Ok(())
}
async fn process_select_response(
mut event_stream: aws_sdk_s3::operation::select_object_content::SelectObjectContentOutput,
) -> Result<String, Box<dyn Error>> {
let mut total_data = Vec::new();
while let Ok(Some(event)) = event_stream.payload.recv().await {
match event {
aws_sdk_s3::types::SelectObjectContentEventStream::Records(records_event) => {
if let Some(payload) = records_event.payload {
let data = payload.into_inner();
total_data.extend_from_slice(&data);
}
}
aws_sdk_s3::types::SelectObjectContentEventStream::End(_) => {
break;
}
_ => {
// Handle other event types (Stats, Progress, Cont, etc.)
}
}
}
Ok(String::from_utf8(total_data)?)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]
async fn test_select_object_content_csv_basic() -> Result<(), Box<dyn Error>> {
let client = create_aws_s3_client().await?;
setup_test_bucket(&client).await?;
upload_test_csv(&client).await?;
// Construct SelectObjectContent request - basic query
let sql = "SELECT * FROM S3Object WHERE age > 28";
let csv_input = CsvInput::builder().file_header_info(FileHeaderInfo::Use).build();
let input_serialization = InputSerialization::builder().csv(csv_input).build();
let csv_output = CsvOutput::builder().build();
let output_serialization = OutputSerialization::builder().csv(csv_output).build();
let response = client
.select_object_content()
.bucket(BUCKET)
.key(CSV_OBJECT)
.expression(sql)
.expression_type(ExpressionType::Sql)
.input_serialization(input_serialization)
.output_serialization(output_serialization)
.send()
.await?;
let result_str = process_select_response(response).await?;
println!("CSV Select result: {result_str}");
// Verify results contain records with age > 28
assert!(result_str.contains("Alice,30,New York"));
assert!(result_str.contains("Charlie,35,Chicago"));
assert!(!result_str.contains("Bob,25,Los Angeles"));
assert!(!result_str.contains("Diana,28,Boston"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]
async fn test_select_object_content_csv_aggregation() -> Result<(), Box<dyn Error>> {
let client = create_aws_s3_client().await?;
setup_test_bucket(&client).await?;
upload_test_csv(&client).await?;
// Construct aggregation query - use simpler approach
let sql = "SELECT name, age FROM S3Object WHERE age >= 25";
let csv_input = CsvInput::builder().file_header_info(FileHeaderInfo::Use).build();
let input_serialization = InputSerialization::builder().csv(csv_input).build();
let csv_output = CsvOutput::builder().build();
let output_serialization = OutputSerialization::builder().csv(csv_output).build();
let response = client
.select_object_content()
.bucket(BUCKET)
.key(CSV_OBJECT)
.expression(sql)
.expression_type(ExpressionType::Sql)
.input_serialization(input_serialization)
.output_serialization(output_serialization)
.send()
.await?;
let result_str = process_select_response(response).await?;
println!("CSV Aggregation result: {result_str}");
// Verify query results - should include records with age >= 25
assert!(result_str.contains("Alice"));
assert!(result_str.contains("Bob"));
assert!(result_str.contains("Charlie"));
assert!(result_str.contains("Diana"));
assert!(result_str.contains("30"));
assert!(result_str.contains("25"));
assert!(result_str.contains("35"));
assert!(result_str.contains("28"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]
async fn test_select_object_content_json_basic() -> Result<(), Box<dyn Error>> {
let client = create_aws_s3_client().await?;
setup_test_bucket(&client).await?;
upload_test_json(&client).await?;
// Construct JSON query
let sql = "SELECT s.name, s.age FROM S3Object s WHERE s.age > 28";
let json_input = JsonInput::builder().set_type(Some(JsonType::Document)).build();
let input_serialization = InputSerialization::builder().json(json_input).build();
let json_output = JsonOutput::builder().build();
let output_serialization = OutputSerialization::builder().json(json_output).build();
let response = client
.select_object_content()
.bucket(BUCKET)
.key(JSON_OBJECT)
.expression(sql)
.expression_type(ExpressionType::Sql)
.input_serialization(input_serialization)
.output_serialization(output_serialization)
.send()
.await?;
let result_str = process_select_response(response).await?;
println!("JSON Select result: {result_str}");
// Verify JSON query results
assert!(result_str.contains("Alice"));
assert!(result_str.contains("Charlie"));
assert!(result_str.contains("30"));
assert!(result_str.contains("35"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]
async fn test_select_object_content_csv_limit() -> Result<(), Box<dyn Error>> {
let client = create_aws_s3_client().await?;
setup_test_bucket(&client).await?;
upload_test_csv(&client).await?;
// Test LIMIT clause
let sql = "SELECT * FROM S3Object LIMIT 2";
let csv_input = CsvInput::builder().file_header_info(FileHeaderInfo::Use).build();
let input_serialization = InputSerialization::builder().csv(csv_input).build();
let csv_output = CsvOutput::builder().build();
let output_serialization = OutputSerialization::builder().csv(csv_output).build();
let response = client
.select_object_content()
.bucket(BUCKET)
.key(CSV_OBJECT)
.expression(sql)
.expression_type(ExpressionType::Sql)
.input_serialization(input_serialization)
.output_serialization(output_serialization)
.send()
.await?;
let result_str = process_select_response(response).await?;
println!("CSV Limit result: {result_str}");
// Verify only first 2 records are returned
let lines: Vec<&str> = result_str.lines().filter(|line| !line.trim().is_empty()).collect();
assert_eq!(lines.len(), 2, "Should return exactly 2 records");
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]
async fn test_select_object_content_csv_order_by() -> Result<(), Box<dyn Error>> {
let client = create_aws_s3_client().await?;
setup_test_bucket(&client).await?;
upload_test_csv(&client).await?;
// Test ORDER BY clause
let sql = "SELECT name, age FROM S3Object ORDER BY age DESC LIMIT 2";
let csv_input = CsvInput::builder().file_header_info(FileHeaderInfo::Use).build();
let input_serialization = InputSerialization::builder().csv(csv_input).build();
let csv_output = CsvOutput::builder().build();
let output_serialization = OutputSerialization::builder().csv(csv_output).build();
let response = client
.select_object_content()
.bucket(BUCKET)
.key(CSV_OBJECT)
.expression(sql)
.expression_type(ExpressionType::Sql)
.input_serialization(input_serialization)
.output_serialization(output_serialization)
.send()
.await?;
let result_str = process_select_response(response).await?;
println!("CSV Order By result: {result_str}");
// Verify ordered by age descending
let lines: Vec<&str> = result_str.lines().filter(|line| !line.trim().is_empty()).collect();
assert!(lines.len() >= 2, "Should return at least 2 records");
// Check if contains highest age records
assert!(result_str.contains("Charlie,35"));
assert!(result_str.contains("Alice,30"));
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]
async fn test_select_object_content_error_handling() -> Result<(), Box<dyn Error>> {
let client = create_aws_s3_client().await?;
setup_test_bucket(&client).await?;
upload_test_csv(&client).await?;
// Test invalid SQL query
let sql = "SELECT * FROM S3Object WHERE invalid_column > 10";
let csv_input = CsvInput::builder().file_header_info(FileHeaderInfo::Use).build();
let input_serialization = InputSerialization::builder().csv(csv_input).build();
let csv_output = CsvOutput::builder().build();
let output_serialization = OutputSerialization::builder().csv(csv_output).build();
// This query should fail because invalid_column doesn't exist
let result = client
.select_object_content()
.bucket(BUCKET)
.key(CSV_OBJECT)
.expression(sql)
.expression_type(ExpressionType::Sql)
.input_serialization(input_serialization)
.output_serialization(output_serialization)
.send()
.await;
// Verify query fails (expected behavior)
assert!(result.is_err(), "Query with invalid column should fail");
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[serial]
#[ignore = "requires running RustFS server at localhost:9000"]
async fn test_select_object_content_nonexistent_object() -> Result<(), Box<dyn Error>> {
let client = create_aws_s3_client().await?;
setup_test_bucket(&client).await?;
// Test query on nonexistent object
let sql = "SELECT * FROM S3Object";
let csv_input = CsvInput::builder().file_header_info(FileHeaderInfo::Use).build();
let input_serialization = InputSerialization::builder().csv(csv_input).build();
let csv_output = CsvOutput::builder().build();
let output_serialization = OutputSerialization::builder().csv(csv_output).build();
let result = client
.select_object_content()
.bucket(BUCKET)
.key("nonexistent.csv")
.expression(sql)
.expression_type(ExpressionType::Sql)
.input_serialization(input_serialization)
.output_serialization(output_serialization)
.send()
.await;
// Verify query fails (expected behavior)
assert!(result.is_err(), "Query on nonexistent object should fail");
Ok(())
}

View File

@@ -1690,6 +1690,15 @@ impl DiskAPI for LocalDisk {
};
out.write_obj(&meta).await?;
objs_returned += 1;
} else {
let fpath =
self.get_object_path(&opts.bucket, path_join_buf(&[opts.base_dir.as_str(), STORAGE_FORMAT_FILE]).as_str())?;
if let Ok(meta) = tokio::fs::metadata(fpath).await
&& meta.is_file()
{
return Err(DiskError::FileNotFound);
}
}
}

View File

@@ -1,4 +1,3 @@
// #![allow(dead_code)]
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -71,11 +71,11 @@ impl NamespaceLock {
}
/// Get resource key for this namespace
fn get_resource_key(&self, resource: &str) -> String {
pub fn get_resource_key(&self, resource: &str) -> String {
format!("{}:{}", self.namespace, resource)
}
/// Acquire lock using clients
/// Acquire lock using clients with transactional semantics (all-or-nothing)
pub async fn acquire_lock(&self, request: &LockRequest) -> Result<LockResponse> {
if self.clients.is_empty() {
return Err(LockError::internal("No lock clients available"));
@@ -86,17 +86,53 @@ impl NamespaceLock {
return self.clients[0].acquire_lock(request).await;
}
// For multiple clients, try to acquire from all clients and require quorum
// Two-phase commit for distributed lock acquisition
self.acquire_lock_with_2pc(request).await
}
/// Two-phase commit lock acquisition: all nodes must succeed or all fail
async fn acquire_lock_with_2pc(&self, request: &LockRequest) -> Result<LockResponse> {
// Phase 1: Prepare - try to acquire lock on all clients
let futures: Vec<_> = self
.clients
.iter()
.map(|client| async move { client.acquire_lock(request).await })
.enumerate()
.map(|(idx, client)| async move {
let result = client.acquire_lock(request).await;
(idx, result)
})
.collect();
let results = futures::future::join_all(futures).await;
let successful = results.into_iter().filter_map(|r| r.ok()).filter(|r| r.success).count();
let mut successful_clients = Vec::new();
let mut failed_clients = Vec::new();
if successful >= self.quorum {
// Collect results
for (idx, result) in results {
match result {
Ok(response) if response.success => {
successful_clients.push(idx);
}
_ => {
failed_clients.push(idx);
}
}
}
// Check if we have enough successful acquisitions for quorum
if successful_clients.len() >= self.quorum {
// Phase 2a: Commit - we have quorum, but need to ensure consistency
// If not all clients succeeded, we need to rollback for consistency
if successful_clients.len() < self.clients.len() {
// Rollback all successful acquisitions to maintain consistency
self.rollback_acquisitions(request, &successful_clients).await;
return Ok(LockResponse::failure(
"Partial success detected, rolled back for consistency".to_string(),
Duration::ZERO,
));
}
// All clients succeeded - lock acquired successfully
Ok(LockResponse::success(
LockInfo {
id: LockId::new_deterministic(&request.resource),
@@ -114,10 +150,38 @@ impl NamespaceLock {
Duration::ZERO,
))
} else {
Ok(LockResponse::failure("Failed to acquire quorum".to_string(), Duration::ZERO))
// Phase 2b: Abort - insufficient quorum, rollback any successful acquisitions
if !successful_clients.is_empty() {
self.rollback_acquisitions(request, &successful_clients).await;
}
Ok(LockResponse::failure(
format!("Failed to acquire quorum: {}/{} required", successful_clients.len(), self.quorum),
Duration::ZERO,
))
}
}
/// Rollback lock acquisitions on specified clients
async fn rollback_acquisitions(&self, request: &LockRequest, client_indices: &[usize]) {
let lock_id = LockId::new_deterministic(&request.resource);
let rollback_futures: Vec<_> = client_indices
.iter()
.filter_map(|&idx| self.clients.get(idx))
.map(|client| async {
if let Err(e) = client.release(&lock_id).await {
tracing::warn!("Failed to rollback lock on client: {}", e);
}
})
.collect();
futures::future::join_all(rollback_futures).await;
tracing::info!(
"Rolled back {} lock acquisitions for resource: {}",
client_indices.len(),
request.resource
);
}
/// Release lock using clients
pub async fn release_lock(&self, lock_id: &LockId) -> Result<bool> {
if self.clients.is_empty() {
@@ -219,7 +283,9 @@ impl NamespaceLockManager for NamespaceLock {
return Err(LockError::internal("No lock clients available"));
}
// For each resource, create a lock request and try to acquire using clients
// Transactional batch lock: all resources must be locked or none
let mut acquired_resources = Vec::new();
for resource in resources {
let namespaced_resource = self.get_resource_key(resource);
let request = LockRequest::new(&namespaced_resource, LockType::Exclusive, owner)
@@ -227,7 +293,11 @@ impl NamespaceLockManager for NamespaceLock {
.with_ttl(ttl);
let response = self.acquire_lock(&request).await?;
if !response.success {
if response.success {
acquired_resources.push(namespaced_resource);
} else {
// Rollback all previously acquired locks
self.rollback_batch_locks(&acquired_resources, owner).await;
return Ok(false);
}
}
@@ -239,12 +309,21 @@ impl NamespaceLockManager for NamespaceLock {
return Err(LockError::internal("No lock clients available"));
}
// For each resource, create a lock ID and try to release using clients
for resource in resources {
let namespaced_resource = self.get_resource_key(resource);
let lock_id = LockId::new_deterministic(&namespaced_resource);
let _ = self.release_lock(&lock_id).await?;
}
// Release all locks (best effort)
let release_futures: Vec<_> = resources
.iter()
.map(|resource| {
let namespaced_resource = self.get_resource_key(resource);
let lock_id = LockId::new_deterministic(&namespaced_resource);
async move {
if let Err(e) = self.release_lock(&lock_id).await {
tracing::warn!("Failed to release lock for resource {}: {}", resource, e);
}
}
})
.collect();
futures::future::join_all(release_futures).await;
Ok(())
}
@@ -253,7 +332,9 @@ impl NamespaceLockManager for NamespaceLock {
return Err(LockError::internal("No lock clients available"));
}
// For each resource, create a shared lock request and try to acquire using clients
// Transactional batch read lock: all resources must be locked or none
let mut acquired_resources = Vec::new();
for resource in resources {
let namespaced_resource = self.get_resource_key(resource);
let request = LockRequest::new(&namespaced_resource, LockType::Shared, owner)
@@ -261,7 +342,11 @@ impl NamespaceLockManager for NamespaceLock {
.with_ttl(ttl);
let response = self.acquire_lock(&request).await?;
if !response.success {
if response.success {
acquired_resources.push(namespaced_resource);
} else {
// Rollback all previously acquired read locks
self.rollback_batch_locks(&acquired_resources, owner).await;
return Ok(false);
}
}
@@ -273,16 +358,45 @@ impl NamespaceLockManager for NamespaceLock {
return Err(LockError::internal("No lock clients available"));
}
// For each resource, create a lock ID and try to release using clients
for resource in resources {
let namespaced_resource = self.get_resource_key(resource);
let lock_id = LockId::new_deterministic(&namespaced_resource);
let _ = self.release_lock(&lock_id).await?;
}
// Release all read locks (best effort)
let release_futures: Vec<_> = resources
.iter()
.map(|resource| {
let namespaced_resource = self.get_resource_key(resource);
let lock_id = LockId::new_deterministic(&namespaced_resource);
async move {
if let Err(e) = self.release_lock(&lock_id).await {
tracing::warn!("Failed to release read lock for resource {}: {}", resource, e);
}
}
})
.collect();
futures::future::join_all(release_futures).await;
Ok(())
}
}
impl NamespaceLock {
/// Rollback batch lock acquisitions
async fn rollback_batch_locks(&self, acquired_resources: &[String], _owner: &str) {
let rollback_futures: Vec<_> = acquired_resources
.iter()
.map(|resource| {
let lock_id = LockId::new_deterministic(resource);
async move {
if let Err(e) = self.release_lock(&lock_id).await {
tracing::warn!("Failed to rollback lock for resource {}: {}", resource, e);
}
}
})
.collect();
futures::future::join_all(rollback_futures).await;
tracing::info!("Rolled back {} batch lock acquisitions", acquired_resources.len());
}
}
#[cfg(test)]
mod tests {
use crate::LocalClient;
@@ -343,4 +457,60 @@ mod tests {
let resource_key = ns_lock.get_resource_key("test-resource");
assert_eq!(resource_key, "test-namespace:test-resource");
}
#[tokio::test]
async fn test_transactional_batch_lock() {
let ns_lock = NamespaceLock::with_client(Arc::new(LocalClient::new()));
let resources = vec!["resource1".to_string(), "resource2".to_string(), "resource3".to_string()];
// First, acquire one of the resources to simulate conflict
let conflicting_request = LockRequest::new(ns_lock.get_resource_key("resource2"), LockType::Exclusive, "other_owner")
.with_ttl(Duration::from_secs(10));
let response = ns_lock.acquire_lock(&conflicting_request).await.unwrap();
assert!(response.success);
// Now try batch lock - should fail and rollback
let result = ns_lock
.lock_batch(&resources, "test_owner", Duration::from_millis(10), Duration::from_secs(5))
.await;
assert!(result.is_ok());
assert!(!result.unwrap()); // Should fail due to conflict
// Verify that no locks were left behind (all rolled back)
for resource in &resources {
if resource != "resource2" {
// Skip the one we intentionally locked
let check_request = LockRequest::new(ns_lock.get_resource_key(resource), LockType::Exclusive, "verify_owner")
.with_ttl(Duration::from_secs(1));
let check_response = ns_lock.acquire_lock(&check_request).await.unwrap();
assert!(check_response.success, "Resource {resource} should be available after rollback");
// Clean up
let lock_id = LockId::new_deterministic(&ns_lock.get_resource_key(resource));
let _ = ns_lock.release_lock(&lock_id).await;
}
}
}
#[tokio::test]
async fn test_distributed_lock_consistency() {
// Create a namespace with multiple local clients to simulate distributed scenario
let client1: Arc<dyn LockClient> = Arc::new(LocalClient::new());
let client2: Arc<dyn LockClient> = Arc::new(LocalClient::new());
let clients = vec![client1, client2];
let ns_lock = NamespaceLock::with_clients("test-namespace".to_string(), clients);
let request = LockRequest::new("test-resource", LockType::Exclusive, "test_owner").with_ttl(Duration::from_secs(10));
// This should succeed only if ALL clients can acquire the lock
let response = ns_lock.acquire_lock(&request).await.unwrap();
// Since we're using separate LocalClient instances, they don't share state
// so this test demonstrates the consistency check
assert!(response.success); // Either all succeed or rollback happens
}
}
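
Editor's note (not part of the diff): a minimal caller-side sketch of the transactional path added above, assuming only the names shown in this file and the e2e tests (NamespaceLock::with_clients, get_resource_key, LockRequest::new, acquire_lock, release_lock, LockId::new_deterministic). With multiple clients, acquire_lock reports success only when every client acquires the lock; otherwise partial acquisitions are rolled back and the response carries success = false.

use std::{error::Error, sync::Arc, time::Duration};
use rustfs_lock::client::{LockClient, local::LocalClient};
use rustfs_lock::{LockId, LockRequest, LockType, NamespaceLock};

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // Two clients stand in for two nodes; the 2PC path requires all of them to succeed.
    let clients: Vec<Arc<dyn LockClient>> = vec![Arc::new(LocalClient::new()), Arc::new(LocalClient::new())];
    let ns = NamespaceLock::with_clients("demo-namespace".to_string(), clients);

    let resource = ns.get_resource_key("demo-resource");
    let request = LockRequest::new(&resource, LockType::Exclusive, "demo-owner").with_ttl(Duration::from_secs(10));

    let response = ns.acquire_lock(&request).await?;
    if response.success {
        // ... critical section ...
        let _ = ns.release_lock(&LockId::new_deterministic(&resource)).await?;
    } else {
        // Quorum was not met or only a subset of clients succeeded; in both cases the
        // 2PC path has already rolled back the partial acquisitions, so nothing is held here.
    }
    Ok(())
}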

View File

@@ -21,6 +21,9 @@ pub mod object_store;
pub mod query;
pub mod server;
#[cfg(test)]
mod test;
pub type QueryResult<T> = Result<T, QueryError>;
#[derive(Debug, Snafu)]
@@ -90,3 +93,82 @@ impl Display for ResolvedTable {
write!(f, "{table}")
}
}
#[cfg(test)]
mod tests {
use super::*;
use datafusion::common::DataFusionError;
use datafusion::sql::sqlparser::parser::ParserError;
#[test]
fn test_query_error_display() {
let err = QueryError::NotImplemented {
err: "feature X".to_string(),
};
assert_eq!(err.to_string(), "This feature is not implemented: feature X");
let err = QueryError::MultiStatement {
num: 2,
sql: "SELECT 1; SELECT 2;".to_string(),
};
assert_eq!(err.to_string(), "Multi-statement not allow, found num:2, sql:SELECT 1; SELECT 2;");
let err = QueryError::Cancel;
assert_eq!(err.to_string(), "The query has been canceled");
let err = QueryError::FunctionNotExists {
name: "my_func".to_string(),
};
assert_eq!(err.to_string(), "Udf not exists, name:my_func.");
let err = QueryError::StoreError {
e: "connection failed".to_string(),
};
assert_eq!(err.to_string(), "Store Error, e:connection failed.");
}
#[test]
fn test_query_error_from_datafusion_error() {
let df_error = DataFusionError::Plan("invalid plan".to_string());
let query_error: QueryError = df_error.into();
match query_error {
QueryError::Datafusion { source, .. } => {
assert!(source.to_string().contains("invalid plan"));
}
_ => panic!("Expected Datafusion error"),
}
}
#[test]
fn test_query_error_from_parser_error() {
let parser_error = ParserError::ParserError("syntax error".to_string());
let query_error = QueryError::Parser { source: parser_error };
assert!(query_error.to_string().contains("syntax error"));
}
#[test]
fn test_resolved_table() {
let table = ResolvedTable {
table: "my_table".to_string(),
};
assert_eq!(table.table(), "my_table");
assert_eq!(table.to_string(), "my_table");
}
#[test]
fn test_resolved_table_clone_and_eq() {
let table1 = ResolvedTable {
table: "table1".to_string(),
};
let table2 = table1.clone();
let table3 = ResolvedTable {
table: "table2".to_string(),
};
assert_eq!(table1, table2);
assert_ne!(table1, table3);
}
}

View File

@@ -0,0 +1,17 @@
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Test modules for s3select-api
pub mod query_execution_test;

View File

@@ -0,0 +1,167 @@
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(test)]
mod tests {
use crate::query::execution::{DONE, Output, QueryExecution, QueryState, QueryType, RUNNING};
use crate::{QueryError, QueryResult};
use async_trait::async_trait;
#[test]
fn test_query_type_display() {
assert_eq!(format!("{}", QueryType::Batch), "batch");
assert_eq!(format!("{}", QueryType::Stream), "stream");
}
#[test]
fn test_query_type_equality() {
assert_eq!(QueryType::Batch, QueryType::Batch);
assert_ne!(QueryType::Batch, QueryType::Stream);
assert_eq!(QueryType::Stream, QueryType::Stream);
}
#[tokio::test]
async fn test_output_nil_methods() {
let output = Output::Nil(());
let result = output.chunk_result().await;
assert!(result.is_ok(), "Output::Nil result should be Ok");
let output2 = Output::Nil(());
let rows = output2.num_rows().await;
assert_eq!(rows, 0, "Output::Nil should have 0 rows");
let output3 = Output::Nil(());
let affected = output3.affected_rows().await;
assert_eq!(affected, 0, "Output::Nil should have 0 affected rows");
}
#[test]
fn test_query_state_as_ref() {
let accepting = QueryState::ACCEPTING;
assert_eq!(accepting.as_ref(), "ACCEPTING");
let running = QueryState::RUNNING(RUNNING::ANALYZING);
assert_eq!(running.as_ref(), "ANALYZING");
let done = QueryState::DONE(DONE::FINISHED);
assert_eq!(done.as_ref(), "FINISHED");
}
#[test]
fn test_running_state_as_ref() {
assert_eq!(RUNNING::DISPATCHING.as_ref(), "DISPATCHING");
assert_eq!(RUNNING::ANALYZING.as_ref(), "ANALYZING");
assert_eq!(RUNNING::OPTMIZING.as_ref(), "OPTMIZING");
assert_eq!(RUNNING::SCHEDULING.as_ref(), "SCHEDULING");
}
#[test]
fn test_done_state_as_ref() {
assert_eq!(DONE::FINISHED.as_ref(), "FINISHED");
assert_eq!(DONE::FAILED.as_ref(), "FAILED");
assert_eq!(DONE::CANCELLED.as_ref(), "CANCELLED");
}
// Mock implementation for testing
struct MockQueryExecution {
should_succeed: bool,
should_cancel: bool,
}
#[async_trait]
impl QueryExecution for MockQueryExecution {
async fn start(&self) -> QueryResult<Output> {
if self.should_cancel {
return Err(QueryError::Cancel);
}
if self.should_succeed {
Ok(Output::Nil(()))
} else {
Err(QueryError::NotImplemented {
err: "Mock execution failed".to_string(),
})
}
}
fn cancel(&self) -> QueryResult<()> {
Ok(())
}
}
#[tokio::test]
async fn test_mock_query_execution_success() {
let execution = MockQueryExecution {
should_succeed: true,
should_cancel: false,
};
let result = execution.start().await;
assert!(result.is_ok(), "Mock execution should succeed");
if let Ok(Output::Nil(_)) = result {
// Expected result
} else {
panic!("Expected Output::Nil");
}
}
#[tokio::test]
async fn test_mock_query_execution_failure() {
let execution = MockQueryExecution {
should_succeed: false,
should_cancel: false,
};
let result = execution.start().await;
assert!(result.is_err(), "Mock execution should fail");
if let Err(QueryError::NotImplemented { .. }) = result {
// Expected error
} else {
panic!("Expected NotImplemented error");
}
}
#[tokio::test]
async fn test_mock_query_execution_cancel() {
let execution = MockQueryExecution {
should_succeed: false,
should_cancel: true,
};
let result = execution.start().await;
assert!(result.is_err(), "Cancelled execution should fail");
if let Err(QueryError::Cancel) = result {
// Expected cancellation error
} else {
panic!("Expected Cancel error");
}
let cancel_result = execution.cancel();
assert!(cancel_result.is_ok(), "Cancel should succeed");
}
#[test]
fn test_query_execution_default_type() {
let execution = MockQueryExecution {
should_succeed: true,
should_cancel: false,
};
assert_eq!(execution.query_type(), QueryType::Batch);
}
}

View File

@@ -107,27 +107,51 @@ pub async fn make_rustfsms(input: Arc<SelectObjectContentInput>, is_test: bool)
Ok(db_server)
}
pub async fn make_rustfsms_with_components(
input: Arc<SelectObjectContentInput>,
is_test: bool,
func_manager: Arc<SimpleFunctionMetadataManager>,
parser: Arc<DefaultParser>,
query_execution_factory: Arc<SqlQueryExecutionFactory>,
default_table_provider: Arc<BaseTableProvider>,
) -> QueryResult<impl DatabaseManagerSystem> {
// TODO session config need load global system config
let session_factory = Arc::new(SessionCtxFactory { is_test });
let query_dispatcher = SimpleQueryDispatcherBuilder::default()
.with_input(input)
.with_func_manager(func_manager)
.with_default_table_provider(default_table_provider)
.with_session_factory(session_factory)
.with_parser(parser)
.with_query_execution_factory(query_execution_factory)
.build()?;
let mut builder = RustFSmsBuilder::default();
let db_server = builder.query_dispatcher(query_dispatcher).build().expect("build db server");
Ok(db_server)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion::{arrow::util::pretty, assert_batches_eq};
use rustfs_s3select_api::{
query::{Context, Query},
server::dbms::DatabaseManagerSystem,
};
use rustfs_s3select_api::query::{Context, Query};
use s3s::dto::{
CSVInput, CSVOutput, ExpressionType, FieldDelimiter, FileHeaderInfo, InputSerialization, OutputSerialization,
RecordDelimiter, SelectObjectContentInput, SelectObjectContentRequest,
};
use crate::instance::make_rustfsms;
use crate::get_global_db;
#[tokio::test]
#[ignore]
async fn test_simple_sql() {
let sql = "select * from S3Object";
let input = Arc::new(SelectObjectContentInput {
let input = SelectObjectContentInput {
bucket: "dandan".to_string(),
expected_bucket_owner: None,
key: "test.csv".to_string(),
@@ -151,9 +175,9 @@ mod tests {
request_progress: None,
scan_range: None,
},
});
let db = make_rustfsms(input.clone(), true).await.unwrap();
let query = Query::new(Context { input }, sql.to_string());
};
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await.unwrap();
@@ -184,7 +208,7 @@ mod tests {
#[ignore]
async fn test_func_sql() {
let sql = "SELECT * FROM S3Object s";
let input = Arc::new(SelectObjectContentInput {
let input = SelectObjectContentInput {
bucket: "dandan".to_string(),
expected_bucket_owner: None,
key: "test.csv".to_string(),
@@ -210,9 +234,9 @@ mod tests {
request_progress: None,
scan_range: None,
},
});
let db = make_rustfsms(input.clone(), true).await.unwrap();
let query = Query::new(Context { input }, sql.to_string());
};
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await.unwrap();

View File

@@ -19,3 +19,84 @@ pub mod function;
pub mod instance;
pub mod metadata;
pub mod sql;
#[cfg(test)]
mod test;
use rustfs_s3select_api::{QueryResult, server::dbms::DatabaseManagerSystem};
use s3s::dto::SelectObjectContentInput;
use std::sync::{Arc, LazyLock};
use crate::{
execution::{factory::SqlQueryExecutionFactory, scheduler::local::LocalScheduler},
function::simple_func_manager::SimpleFunctionMetadataManager,
metadata::base_table::BaseTableProvider,
sql::{optimizer::CascadeOptimizerBuilder, parser::DefaultParser},
};
// Global cached components that can be reused across database instances
struct GlobalComponents {
func_manager: Arc<SimpleFunctionMetadataManager>,
parser: Arc<DefaultParser>,
query_execution_factory: Arc<SqlQueryExecutionFactory>,
default_table_provider: Arc<BaseTableProvider>,
}
static GLOBAL_COMPONENTS: LazyLock<GlobalComponents> = LazyLock::new(|| {
let func_manager = Arc::new(SimpleFunctionMetadataManager::default());
let parser = Arc::new(DefaultParser::default());
let optimizer = Arc::new(CascadeOptimizerBuilder::default().build());
let scheduler = Arc::new(LocalScheduler {});
let query_execution_factory = Arc::new(SqlQueryExecutionFactory::new(optimizer, scheduler));
let default_table_provider = Arc::new(BaseTableProvider::default());
GlobalComponents {
func_manager,
parser,
query_execution_factory,
default_table_provider,
}
});
/// Get or create database instance with cached components
pub async fn get_global_db(
input: SelectObjectContentInput,
enable_debug: bool,
) -> QueryResult<Arc<dyn DatabaseManagerSystem + Send + Sync>> {
let components = &*GLOBAL_COMPONENTS;
let db = crate::instance::make_rustfsms_with_components(
Arc::new(input),
enable_debug,
components.func_manager.clone(),
components.parser.clone(),
components.query_execution_factory.clone(),
components.default_table_provider.clone(),
)
.await?;
Ok(Arc::new(db) as Arc<dyn DatabaseManagerSystem + Send + Sync>)
}
/// Create a fresh database instance without using cached components (for testing)
pub async fn create_fresh_db() -> QueryResult<Arc<dyn DatabaseManagerSystem + Send + Sync>> {
// Create a default test input for fresh database creation
let default_input = SelectObjectContentInput {
bucket: "test-bucket".to_string(),
expected_bucket_owner: None,
key: "test.csv".to_string(),
sse_customer_algorithm: None,
sse_customer_key: None,
sse_customer_key_md5: None,
request: s3s::dto::SelectObjectContentRequest {
expression: "SELECT * FROM S3Object".to_string(),
expression_type: s3s::dto::ExpressionType::from_static("SQL"),
input_serialization: s3s::dto::InputSerialization::default(),
output_serialization: s3s::dto::OutputSerialization::default(),
request_progress: None,
scan_range: None,
},
};
let db = crate::instance::make_rustfsms(Arc::new(default_input), true).await?;
Ok(Arc::new(db) as Arc<dyn DatabaseManagerSystem + Send + Sync>)
}

View File

@@ -0,0 +1,247 @@
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(test)]
mod error_handling_tests {
use crate::get_global_db;
use rustfs_s3select_api::{
QueryError,
query::{Context, Query},
};
use s3s::dto::{
CSVInput, ExpressionType, FileHeaderInfo, InputSerialization, SelectObjectContentInput, SelectObjectContentRequest,
};
use std::sync::Arc;
fn create_test_input_with_sql(sql: &str) -> SelectObjectContentInput {
SelectObjectContentInput {
bucket: "test-bucket".to_string(),
expected_bucket_owner: None,
key: "test.csv".to_string(),
sse_customer_algorithm: None,
sse_customer_key: None,
sse_customer_key_md5: None,
request: SelectObjectContentRequest {
expression: sql.to_string(),
expression_type: ExpressionType::from_static("SQL"),
input_serialization: InputSerialization {
csv: Some(CSVInput {
file_header_info: Some(FileHeaderInfo::from_static(FileHeaderInfo::USE)),
..Default::default()
}),
..Default::default()
},
output_serialization: s3s::dto::OutputSerialization::default(),
request_progress: None,
scan_range: None,
},
}
}
#[tokio::test]
async fn test_syntax_error_handling() {
let invalid_sqls = vec![
"INVALID SQL",
"SELECT FROM",
"SELECT * FORM S3Object", // typo in FROM
"SELECT * FROM",
"SELECT * FROM S3Object WHERE",
"SELECT COUNT( FROM S3Object", // missing closing parenthesis
];
for sql in invalid_sqls {
let input = create_test_input_with_sql(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_err(), "Expected error for SQL: {sql}");
}
}
#[tokio::test]
async fn test_multi_statement_error() {
let multi_statement_sqls = vec![
"SELECT * FROM S3Object; SELECT 1;",
"SELECT 1; SELECT 2; SELECT 3;",
"SELECT * FROM S3Object; DROP TABLE test;",
];
for sql in multi_statement_sqls {
let input = create_test_input_with_sql(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_err(), "Expected multi-statement error for SQL: {sql}");
if let Err(QueryError::MultiStatement { num, .. }) = result {
assert!(num >= 2, "Expected at least 2 statements, got: {num}");
}
}
}
#[tokio::test]
async fn test_unsupported_operations() {
let unsupported_sqls = vec![
"INSERT INTO S3Object VALUES (1, 'test')",
"UPDATE S3Object SET name = 'test'",
"DELETE FROM S3Object",
"CREATE TABLE test (id INT)",
"DROP TABLE S3Object",
];
for sql in unsupported_sqls {
let input = create_test_input_with_sql(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
// These should fail with either a syntax error or a not-implemented error
assert!(result.is_err(), "Expected error for unsupported SQL: {sql}");
}
}
#[tokio::test]
async fn test_invalid_column_references() {
let invalid_column_sqls = vec![
"SELECT nonexistent_column FROM S3Object",
"SELECT * FROM S3Object WHERE nonexistent_column = 1",
"SELECT * FROM S3Object ORDER BY nonexistent_column",
"SELECT * FROM S3Object GROUP BY nonexistent_column",
];
for sql in invalid_column_sqls {
let input = create_test_input_with_sql(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
// These might succeed or fail depending on schema inference
// The test verifies that the system handles them gracefully
match result {
Ok(handle) => {
// If it succeeds, verify we can get results
let output = handle.result().chunk_result().await;
// Should either succeed with empty results or fail gracefully
let _ = output;
}
Err(_) => {
// Expected to fail - this is acceptable
}
}
}
}
#[tokio::test]
async fn test_complex_query_error_recovery() {
let complex_invalid_sql = r#"
SELECT
name,
age,
INVALID_FUNCTION(salary) as invalid_calc,
department
FROM S3Object
WHERE age > 'invalid_number'
GROUP BY department, nonexistent_column
HAVING COUNT(*) > INVALID_FUNCTION()
ORDER BY invalid_column
"#;
let input = create_test_input_with_sql(complex_invalid_sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, complex_invalid_sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_err(), "Expected error for complex invalid SQL");
}
#[tokio::test]
async fn test_empty_query() {
let empty_sqls = vec!["", " ", "\n\t \n"];
for sql in empty_sqls {
let input = create_test_input_with_sql(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
// Empty queries might be handled differently by the parser
match result {
Ok(_) => {
// Some parsers might accept empty queries
}
Err(_) => {
// Expected to fail for empty SQL
}
}
}
}
#[tokio::test]
async fn test_very_long_query() {
// Create a very long but valid query
let mut long_sql = "SELECT ".to_string();
for i in 0..1000 {
if i > 0 {
long_sql.push_str(", ");
}
long_sql.push_str(&format!("'column_{i}' as col_{i}"));
}
long_sql.push_str(" FROM S3Object LIMIT 1");
let input = create_test_input_with_sql(&long_sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, long_sql);
let result = db.execute(&query).await;
// This should either succeed or fail gracefully
match result {
Ok(handle) => {
let output = handle.result().chunk_result().await;
assert!(output.is_ok(), "Query execution should complete successfully");
}
Err(_) => {
// Acceptable to fail due to resource constraints
}
}
}
#[tokio::test]
async fn test_sql_injection_patterns() {
let injection_patterns = vec![
"SELECT * FROM S3Object WHERE name = 'test'; DROP TABLE users; --",
"SELECT * FROM S3Object UNION SELECT * FROM information_schema.tables",
"SELECT * FROM S3Object WHERE 1=1 OR 1=1",
];
for sql in injection_patterns {
let input = create_test_input_with_sql(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
// These should be handled safely - either succeed with limited scope or fail
match result {
Ok(_) => {
// If successful, it should only access S3Object data
}
Err(_) => {
// Expected to fail for security reasons
}
}
}
}
}
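The multi-statement assertions above are the only place these tests inspect a concrete QueryError variant; a caller can branch on it the same way to produce a clearer message. A hedged sketch, assuming only the num field shown in the tests and that QueryError implements Display as error types typically do:

use rustfs_s3select_api::QueryError;

// Map the variant the tests assert on to a human-readable message and fall
// back to the error's Display output for everything else.
fn describe(err: &QueryError) -> String {
    match err {
        QueryError::MultiStatement { num, .. } => {
            format!("expected exactly one SQL statement, got {num}")
        }
        other => other.to_string(),
    }
}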

View File

@@ -0,0 +1,228 @@
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(test)]
mod integration_tests {
use crate::{create_fresh_db, get_global_db, instance::make_rustfsms};
use rustfs_s3select_api::{
QueryError,
query::{Context, Query},
};
use s3s::dto::{
CSVInput, CSVOutput, ExpressionType, FileHeaderInfo, InputSerialization, OutputSerialization, SelectObjectContentInput,
SelectObjectContentRequest,
};
use std::sync::Arc;
fn create_test_input(sql: &str) -> SelectObjectContentInput {
SelectObjectContentInput {
bucket: "test-bucket".to_string(),
expected_bucket_owner: None,
key: "test.csv".to_string(),
sse_customer_algorithm: None,
sse_customer_key: None,
sse_customer_key_md5: None,
request: SelectObjectContentRequest {
expression: sql.to_string(),
expression_type: ExpressionType::from_static("SQL"),
input_serialization: InputSerialization {
csv: Some(CSVInput {
file_header_info: Some(FileHeaderInfo::from_static(FileHeaderInfo::USE)),
..Default::default()
}),
..Default::default()
},
output_serialization: OutputSerialization {
csv: Some(CSVOutput::default()),
..Default::default()
},
request_progress: None,
scan_range: None,
},
}
}
#[tokio::test]
async fn test_database_creation() {
let input = create_test_input("SELECT * FROM S3Object");
let result = make_rustfsms(Arc::new(input), true).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_global_db_creation() {
let input = create_test_input("SELECT * FROM S3Object");
let result = get_global_db(input.clone(), true).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_fresh_db_creation() {
let result = create_fresh_db().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_simple_select_query() {
let sql = "SELECT * FROM S3Object";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_ok());
let query_handle = result.unwrap();
let output = query_handle.result().chunk_result().await;
assert!(output.is_ok());
}
#[tokio::test]
async fn test_select_with_where_clause() {
let sql = "SELECT name, age FROM S3Object WHERE age > 30";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_select_with_aggregation() {
let sql = "SELECT department, COUNT(*) as count FROM S3Object GROUP BY department";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
// Aggregation queries might fail due to lack of actual data, which is acceptable
match result {
Ok(_) => {
// If successful, that's great
}
Err(_) => {
// Expected to fail because there is no actual data source
}
}
}
#[tokio::test]
async fn test_invalid_sql_syntax() {
let sql = "INVALID SQL SYNTAX";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_multi_statement_error() {
let sql = "SELECT * FROM S3Object; SELECT 1;";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_err());
if let Err(QueryError::MultiStatement { num, .. }) = result {
assert_eq!(num, 2);
} else {
panic!("Expected MultiStatement error");
}
}
#[tokio::test]
async fn test_query_state_machine_workflow() {
let sql = "SELECT * FROM S3Object";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
// Test state machine creation
let state_machine = db.build_query_state_machine(query.clone()).await;
assert!(state_machine.is_ok());
let state_machine = state_machine.unwrap();
// Test logical plan building
let logical_plan = db.build_logical_plan(state_machine.clone()).await;
assert!(logical_plan.is_ok());
// Test execution if plan exists
if let Ok(Some(plan)) = logical_plan {
let execution_result = db.execute_logical_plan(plan, state_machine).await;
assert!(execution_result.is_ok());
}
}
#[tokio::test]
async fn test_query_with_limit() {
let sql = "SELECT * FROM S3Object LIMIT 5";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_ok());
let query_handle = result.unwrap();
let output = query_handle.result().chunk_result().await.unwrap();
// Verify that we get results (exact count depends on test data)
let total_rows: usize = output.iter().map(|batch| batch.num_rows()).sum();
assert!(total_rows <= 5);
}
#[tokio::test]
async fn test_query_with_order_by() {
let sql = "SELECT name, age FROM S3Object ORDER BY age DESC";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
let query = Query::new(Context { input: Arc::new(input) }, sql.to_string());
let result = db.execute(&query).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_concurrent_queries() {
let sql = "SELECT * FROM S3Object";
let input = create_test_input(sql);
let db = get_global_db(input.clone(), true).await.unwrap();
// Execute multiple queries concurrently
let mut handles = vec![];
for i in 0..3 {
let query = Query::new(
Context {
input: Arc::new(input.clone()),
},
format!("SELECT * FROM S3Object LIMIT {}", i + 1),
);
let db_clone = db.clone();
let handle = tokio::spawn(async move { db_clone.execute(&query).await });
handles.push(handle);
}
// Wait for all queries to complete
for handle in handles {
let result = handle.await.unwrap();
assert!(result.is_ok());
}
}
}

View File

@@ -0,0 +1,18 @@
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Test modules for s3select-query
pub mod error_handling_test;
pub mod integration_test;

View File

@@ -32,7 +32,7 @@ use rustfs_ecstore::set_disk::MAX_PARTS_COUNT;
use rustfs_s3select_api::object_store::bytes_stream;
use rustfs_s3select_api::query::Context;
use rustfs_s3select_api::query::Query;
use rustfs_s3select_api::server::dbms::DatabaseManagerSystem;
use rustfs_s3select_query::get_global_db;
// use rustfs_ecstore::store_api::RESERVED_METADATA_PREFIX;
use futures::StreamExt;
@@ -86,7 +86,6 @@ use rustfs_rio::EtagReader;
use rustfs_rio::HashReader;
use rustfs_rio::Reader;
use rustfs_rio::WarpReader;
use rustfs_s3select_query::instance::make_rustfsms;
use rustfs_utils::CompressionAlgorithm;
use rustfs_utils::path::path_join_buf;
use rustfs_zip::CompressionFormat;
@@ -2674,8 +2673,8 @@ impl S3 for FS {
let input = Arc::new(req.input);
info!("{:?}", input);
let db = make_rustfsms(input.clone(), false).await.map_err(|e| {
error!("make db failed, {}", e.to_string());
let db = get_global_db((*input).clone(), false).await.map_err(|e| {
error!("get global db failed, {}", e.to_string());
s3_error!(InternalError, "{}", e.to_string())
})?;
let query = Query::new(Context { input: input.clone() }, input.request.expression.clone());

scripts/run_e2e_tests.sh (new executable file, 320 lines)
View File

@@ -0,0 +1,320 @@
#!/bin/bash
# E2E Test Runner Script
# Automatically starts a RustFS instance, runs the tests, and cleans up
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Default values
PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
TARGET_DIR="$PROJECT_ROOT/target/debug"
RUSTFS_BINARY="$TARGET_DIR/rustfs"
DATA_DIR="$TARGET_DIR/rustfs_test_data"
RUSTFS_PID=""
TEST_FILTER=""
TEST_TYPE="all"
# Function to print colored output
print_info() {
echo -e "${BLUE}[INFO]${NC} $1"
}
print_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
print_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Function to show usage
show_usage() {
cat << EOF
Usage: $0 [OPTIONS]
Options:
-h, --help Show this help message
-t, --test <pattern> Run specific test(s) matching pattern
-f, --file <file> Run all tests in specific file (e.g., sql, basic)
-a, --all Run all e2e tests (default)
Examples:
$0 # Run all e2e tests
$0 -t test_select_object_content_csv_basic # Run specific test
$0 -f sql # Run all SQL tests
$0 -f reliant::sql # Run all tests in sql.rs file
EOF
}
# Function to cleanup on exit
cleanup() {
print_info "Cleaning up..."
# Stop RustFS if running
if [ ! -z "$RUSTFS_PID" ] && kill -0 "$RUSTFS_PID" 2>/dev/null; then
print_info "Stopping RustFS (PID: $RUSTFS_PID)..."
kill "$RUSTFS_PID" 2>/dev/null || true
sleep 2
# Force kill if still running
if kill -0 "$RUSTFS_PID" 2>/dev/null; then
print_warning "Force killing RustFS..."
kill -9 "$RUSTFS_PID" 2>/dev/null || true
fi
fi
# Clean up data directory
if [ -d "$DATA_DIR" ]; then
print_info "Removing test data directory: $DATA_DIR"
rm -rf "$DATA_DIR"
fi
print_success "Cleanup completed"
}
# Set trap to cleanup on exit
trap cleanup EXIT INT TERM
# Function to build RustFS
build_rustfs() {
print_info "Building RustFS..."
cd "$PROJECT_ROOT"
if ! cargo build --bin rustfs; then
print_error "Failed to build RustFS"
exit 1
fi
if [ ! -f "$RUSTFS_BINARY" ]; then
print_error "RustFS binary not found at: $RUSTFS_BINARY"
exit 1
fi
print_success "RustFS built successfully"
}
# Function to check if required tools are available
check_dependencies() {
local missing_tools=()
if ! command -v curl >/dev/null 2>&1; then
missing_tools+=("curl")
fi
if ! command -v cargo >/dev/null 2>&1; then
missing_tools+=("cargo")
fi
if [ ${#missing_tools[@]} -gt 0 ]; then
print_error "Missing required tools: ${missing_tools[*]}"
print_error "Please install the missing tools and try again"
exit 1
fi
}
# Function to start RustFS
start_rustfs() {
print_info "Starting RustFS instance..."
# Create data directory and logs directory
mkdir -p "$DATA_DIR"
mkdir -p "$TARGET_DIR/logs"
# Start RustFS in background with environment variables
cd "$TARGET_DIR"
RUSTFS_ACCESS_KEY=rustfsadmin RUSTFS_SECRET_KEY=rustfsadmin \
RUSTFS_OBS_LOG_DIRECTORY="$TARGET_DIR/logs" \
./rustfs --address :9000 "$DATA_DIR" > rustfs.log 2>&1 &
RUSTFS_PID=$!
print_info "RustFS started with PID: $RUSTFS_PID"
print_info "Data directory: $DATA_DIR"
print_info "Log file: $TARGET_DIR/rustfs.log"
print_info "RustFS logs directory: $TARGET_DIR/logs"
# Wait for RustFS to be ready
print_info "Waiting for RustFS to be ready..."
local max_attempts=15 # Reduced from 30 to 15 seconds
local attempt=0
while [ $attempt -lt $max_attempts ]; do
# Check if process is still running first (faster check)
if ! kill -0 "$RUSTFS_PID" 2>/dev/null; then
print_error "RustFS process died unexpectedly"
print_error "Log output:"
cat "$TARGET_DIR/rustfs.log" || true
exit 1
fi
# Try simple HTTP connection first (most reliable)
if curl -s --noproxy localhost --connect-timeout 2 --max-time 3 "http://localhost:9000/" >/dev/null 2>&1; then
print_success "RustFS is ready!"
return 0
fi
# Try health endpoint if available
if curl -s --noproxy localhost --connect-timeout 2 --max-time 3 "http://localhost:9000/health" >/dev/null 2>&1; then
print_success "RustFS is ready!"
return 0
fi
# Try port connectivity check (faster than HTTP)
if nc -z localhost 9000 2>/dev/null; then
print_info "Port 9000 is open, verifying HTTP response..."
if curl -s --noproxy localhost --connect-timeout 1 --max-time 2 "http://localhost:9000/" >/dev/null 2>&1; then
print_success "RustFS is ready!"
return 0
fi
fi
sleep 1
attempt=$((attempt + 1))
echo -n "."
done
echo
print_warning "RustFS health check failed within $max_attempts seconds"
print_info "Checking if RustFS process is still running..."
if kill -0 "$RUSTFS_PID" 2>/dev/null; then
print_info "RustFS process is still running (PID: $RUSTFS_PID)"
print_info "Trying final connection attempts..."
# Quick final attempts with shorter timeouts
for i in 1 2 3; do
if curl -s --noproxy localhost --connect-timeout 1 --max-time 2 "http://localhost:9000/" >/dev/null 2>&1; then
print_success "RustFS is now ready!"
return 0
fi
if nc -z localhost 9000 2>/dev/null; then
print_info "Port 9000 is accessible, continuing with tests..."
return 0
fi
sleep 1
done
print_warning "RustFS may be slow to respond, but process is running"
print_info "Continuing with tests anyway..."
return 0
else
print_error "RustFS process has died"
print_error "Log output:"
cat "$TARGET_DIR/rustfs.log" || true
return 1
fi
}
# Function to run tests
run_tests() {
print_info "Running e2e tests..."
cd "$PROJECT_ROOT"
local test_cmd="cargo test --package e2e_test --lib"
case "$TEST_TYPE" in
"specific")
test_cmd="$test_cmd -- $TEST_FILTER --exact --show-output --ignored"
print_info "Running specific test: $TEST_FILTER"
;;
"file")
test_cmd="$test_cmd -- $TEST_FILTER --show-output --ignored"
print_info "Running tests in file/module: $TEST_FILTER"
;;
"all")
test_cmd="$test_cmd -- --show-output --ignored"
print_info "Running all e2e tests"
;;
esac
print_info "Test command: $test_cmd"
if eval "$test_cmd"; then
print_success "All tests passed!"
return 0
else
print_error "Some tests failed!"
return 1
fi
}
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
show_usage
exit 0
;;
-t|--test)
TEST_FILTER="$2"
TEST_TYPE="specific"
shift 2
;;
-f|--file)
TEST_FILTER="$2"
TEST_TYPE="file"
shift 2
;;
-a|--all)
TEST_TYPE="all"
shift
;;
*)
print_error "Unknown option: $1"
show_usage
exit 1
;;
esac
done
# Main execution
main() {
print_info "Starting E2E Test Runner"
print_info "Project root: $PROJECT_ROOT"
print_info "Target directory: $TARGET_DIR"
# Check dependencies
check_dependencies
# Build RustFS
build_rustfs
# Start RustFS
if ! start_rustfs; then
print_error "Failed to start RustFS properly"
print_info "Checking if we can still run tests..."
if [ ! -z "$RUSTFS_PID" ] && kill -0 "$RUSTFS_PID" 2>/dev/null; then
print_info "RustFS process is still running, attempting to continue..."
else
print_error "RustFS is not running, cannot proceed with tests"
exit 1
fi
fi
# Run tests
local test_result=0
run_tests || test_result=$?
# Cleanup will be handled by trap
if [ $test_result -eq 0 ]; then
print_success "E2E tests completed successfully!"
exit 0
else
print_error "E2E tests failed!"
exit 1
fi
}
# Run main function
main "$@"