Merge pull request #284 from rustfs/feature/s3select

Support functions in S3 Select queries: register DataFusion's default UDFs/UDAFs/UDWFs and expose their names to the SQL planner.
junxiangMu
2025-04-01 11:55:14 +08:00
committed by GitHub
5 changed files with 110 additions and 33 deletions

View File

@@ -1,4 +1,3 @@
-use std::collections::HashSet;
 use std::sync::Arc;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
@@ -7,11 +6,11 @@ use crate::QueryResult;
 pub type FuncMetaManagerRef = Arc<dyn FunctionMetadataManager + Send + Sync>;
 pub trait FunctionMetadataManager {
-    fn register_udf(&mut self, udf: ScalarUDF) -> QueryResult<()>;
+    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> QueryResult<()>;
-    fn register_udaf(&mut self, udaf: AggregateUDF) -> QueryResult<()>;
+    fn register_udaf(&mut self, udaf: Arc<AggregateUDF>) -> QueryResult<()>;
-    fn register_udwf(&mut self, udwf: WindowUDF) -> QueryResult<()>;
+    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> QueryResult<()>;
     fn udf(&self, name: &str) -> QueryResult<Arc<ScalarUDF>>;
@@ -19,5 +18,7 @@ pub trait FunctionMetadataManager {
     fn udwf(&self, name: &str) -> QueryResult<Arc<WindowUDF>>;
-    fn udfs(&self) -> HashSet<String>;
+    fn udfs(&self) -> Vec<String>;
+    fn udafs(&self) -> Vec<String>;
+    fn udwfs(&self) -> Vec<String>;
 }
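
The trait now takes `Arc`-wrapped function handles and reports registered names as `Vec<String>`, so the same UDF definitions can be shared with a DataFusion `SessionContext` without cloning. A minimal caller sketch against this trait; the `register_defaults` helper is hypothetical, and it only uses APIs visible in this diff plus `SessionStateDefaults` (imported by the implementation further below):

    use std::sync::Arc;
    use datafusion::execution::SessionStateDefaults;

    // Hypothetical helper: seed any FunctionMetadataManager with DataFusion's
    // built-in scalar functions, which already come as Arc<ScalarUDF>.
    fn register_defaults(manager: &mut dyn FunctionMetadataManager) -> QueryResult<()> {
        for udf in SessionStateDefaults::default_scalar_functions() {
            manager.register_udf(udf)?; // no Arc::new needed anymore
        }
        Ok(())
    }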

View File

@@ -139,7 +139,7 @@ impl SimpleQueryDispatcher {
         let path = format!("s3://{}/{}", self.input.bucket, self.input.key);
         let table_path = ListingTableUrl::parse(path)?;
         let listing_options = if self.input.request.input_serialization.csv.is_some() {
-            let file_format = CsvFormat::default().with_options(CsvOptions::default().with_has_header(false));
+            let file_format = CsvFormat::default().with_options(CsvOptions::default().with_has_header(true));
             ListingOptions::new(Arc::new(file_format)).with_file_extension(".csv")
         } else if self.input.request.input_serialization.parquet.is_some() {
             let file_format = ParquetFormat::new();
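
Flipping `with_has_header` from `false` to `true` tells DataFusion to read the first CSV row as the schema, which is why the test expectations below change from positional `column_1`-style names to the real `id`/`name`/... headers, and why an expression like `count(s.id)` can resolve at all. A sketch of the behavior, mirroring the calls in this diff; the import paths are assumed to match this crate's DataFusion version:

    use std::sync::Arc;
    use datafusion::common::config::CsvOptions;
    use datafusion::datasource::file_format::csv::CsvFormat;
    use datafusion::datasource::listing::ListingOptions;

    // has_header = false: columns become column_1..column_n and the header
    // line ("id,name,age,...") is returned as an ordinary data row.
    // has_header = true: columns are named id, name, age, department, salary.
    fn csv_listing_options() -> ListingOptions {
        let file_format = CsvFormat::default()
            .with_options(CsvOptions::default().with_has_header(true));
        ListingOptions::new(Arc::new(file_format)).with_file_extension(".csv")
    }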

View File

@@ -1,13 +1,15 @@
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::sync::Arc;
 use api::query::function::FunctionMetadataManager;
 use api::{QueryError, QueryResult};
+use datafusion::execution::SessionStateDefaults;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
+use tracing::debug;
 pub type SimpleFunctionMetadataManagerRef = Arc<SimpleFunctionMetadataManager>;
-#[derive(Debug, Default)]
+#[derive(Debug)]
 pub struct SimpleFunctionMetadataManager {
     /// Scalar functions that are registered with the context
     pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
@@ -17,19 +19,53 @@ pub struct SimpleFunctionMetadataManager {
     pub window_functions: HashMap<String, Arc<WindowUDF>>,
 }
+impl Default for SimpleFunctionMetadataManager {
+    fn default() -> Self {
+        let mut func_meta_manager = Self {
+            scalar_functions: Default::default(),
+            aggregate_functions: Default::default(),
+            window_functions: Default::default(),
+        };
+        SessionStateDefaults::default_scalar_functions().into_iter().for_each(|udf| {
+            let existing_udf = func_meta_manager.register_udf(udf.clone());
+            if let Ok(()) = existing_udf {
+                debug!("Overwrote an existing UDF: {}", udf.name());
+            }
+        });
+        SessionStateDefaults::default_aggregate_functions()
+            .into_iter()
+            .for_each(|udaf| {
+                let existing_udaf = func_meta_manager.register_udaf(udaf.clone());
+                if let Ok(()) = existing_udaf {
+                    debug!("Overwrote an existing UDAF: {}", udaf.name());
+                }
+            });
+        SessionStateDefaults::default_window_functions().into_iter().for_each(|udwf| {
+            let existing_udwf = func_meta_manager.register_udwf(udwf.clone());
+            if let Ok(()) = existing_udwf {
+                debug!("Overwrote an existing UDWF: {}", udwf.name());
+            }
+        });
+        func_meta_manager
+    }
+}
 impl FunctionMetadataManager for SimpleFunctionMetadataManager {
-    fn register_udf(&mut self, f: ScalarUDF) -> QueryResult<()> {
-        self.scalar_functions.insert(f.inner().name().to_uppercase(), Arc::new(f));
+    fn register_udf(&mut self, f: Arc<ScalarUDF>) -> QueryResult<()> {
+        self.scalar_functions.insert(f.inner().name().to_uppercase(), f);
         Ok(())
     }
-    fn register_udaf(&mut self, f: AggregateUDF) -> QueryResult<()> {
-        self.aggregate_functions.insert(f.inner().name().to_uppercase(), Arc::new(f));
+    fn register_udaf(&mut self, f: Arc<AggregateUDF>) -> QueryResult<()> {
+        self.aggregate_functions.insert(f.inner().name().to_uppercase(), f);
         Ok(())
     }
-    fn register_udwf(&mut self, f: WindowUDF) -> QueryResult<()> {
-        self.window_functions.insert(f.inner().name().to_uppercase(), Arc::new(f));
+    fn register_udwf(&mut self, f: Arc<WindowUDF>) -> QueryResult<()> {
+        self.window_functions.insert(f.inner().name().to_uppercase(), f);
        Ok(())
     }
@@ -57,7 +93,13 @@ impl FunctionMetadataManager for SimpleFunctionMetadataManager {
             .ok_or_else(|| QueryError::FunctionNotExists { name: name.to_string() })
     }
-    fn udfs(&self) -> HashSet<String> {
+    fn udfs(&self) -> Vec<String> {
         self.scalar_functions.keys().cloned().collect()
     }
+    fn udafs(&self) -> Vec<String> {
+        self.aggregate_functions.keys().cloned().collect()
+    }
+    fn udwfs(&self) -> Vec<String> {
+        self.window_functions.keys().cloned().collect()
+    }
 }
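
Every map is keyed by the uppercased function name, so lookups are case-normalized as long as callers uppercase too, and re-registering a name silently replaces the previous entry (plain `HashMap::insert`), which is what the "Overwrote an existing ..." debug lines refer to. A usage sketch assuming the seeded defaults; the exact built-in names available depend on the DataFusion version:

    // Default::default() pre-registers DataFusion's built-ins under uppercase
    // keys, so e.g. COUNT should be resolvable for the S3 Select planner.
    let manager = SimpleFunctionMetadataManager::default();
    assert!(manager.udafs().contains(&"COUNT".to_string()));
    let abs_udf = manager.udf("ABS").expect("seeded by Default");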

View File

@@ -141,24 +141,58 @@ mod tests {
         let results = result.result().chunk_result().await.unwrap().to_vec();
         let expected = [
-            "+----------------+----------+----------+------------+----------+",
-            "| column_1       | column_2 | column_3 | column_4   | column_5 |",
-            "+----------------+----------+----------+------------+----------+",
-            "| id             | name     | age      | department | salary   |",
-            "| 1              | Alice    | 25       | HR         | 5000     |",
-            "| 2              | Bob      | 30       | IT         | 6000     |",
-            "| 3              | Charlie  | 35       | Finance    | 7000     |",
-            "| 4              | Diana    | 22       | Marketing  | 4500     |",
-            "| 5              | Eve      | 28       | IT         | 5500     |",
-            "| 6              | Frank    | 40       | Finance    | 8000     |",
-            "| 7              | Grace    | 26       | HR         | 5200     |",
-            "| 8              | Henry    | 32       | IT         | 6200     |",
-            "| 9              | Ivy      | 24       | Marketing  | 4800     |",
-            "| 10             | Jack     | 38       | Finance    | 7500     |",
-            "+----------------+----------+----------+------------+----------+",
+            "+----------------+---------+-----+------------+--------+",
+            "| id             | name    | age | department | salary |",
+            "+----------------+---------+-----+------------+--------+",
+            "| 1              | Alice   | 25  | HR         | 5000   |",
+            "| 2              | Bob     | 30  | IT         | 6000   |",
+            "| 3              | Charlie | 35  | Finance    | 7000   |",
+            "| 4              | Diana   | 22  | Marketing  | 4500   |",
+            "| 5              | Eve     | 28  | IT         | 5500   |",
+            "| 6              | Frank   | 40  | Finance    | 8000   |",
+            "| 7              | Grace   | 26  | HR         | 5200   |",
+            "| 8              | Henry   | 32  | IT         | 6200   |",
+            "| 9              | Ivy     | 24  | Marketing  | 4800   |",
+            "| 10             | Jack    | 38  | Finance    | 7500   |",
+            "+----------------+---------+-----+------------+--------+",
         ];
         assert_batches_eq!(expected, &results);
         pretty::print_batches(&results).unwrap();
     }
+    #[tokio::test]
+    #[ignore]
+    async fn test_func_sql() {
+        let sql = "select count(s.id) from S3Object as s";
+        let input = SelectObjectContentInput {
+            bucket: "dandan".to_string(),
+            expected_bucket_owner: None,
+            key: "test.csv".to_string(),
+            sse_customer_algorithm: None,
+            sse_customer_key: None,
+            sse_customer_key_md5: None,
+            request: SelectObjectContentRequest {
+                expression: sql.to_string(),
+                expression_type: ExpressionType::from_static("SQL"),
+                input_serialization: InputSerialization {
+                    csv: Some(CSVInput::default()),
+                    ..Default::default()
+                },
+                output_serialization: OutputSerialization {
+                    csv: Some(CSVOutput::default()),
+                    ..Default::default()
+                },
+                request_progress: None,
+                scan_range: None,
+            },
+        };
+        let db = make_rustfsms(input.clone(), true).await.unwrap();
+        let query = Query::new(Context { input }, sql.to_string());
+        let result = db.execute(&query).await.unwrap();
+        let results = result.result().chunk_result().await.unwrap().to_vec();
+        pretty::print_batches(&results).unwrap();
+    }
 }
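
The new `test_func_sql` only pretty-prints the batches and is marked `#[ignore]`, so it does not run by default. A hypothetical tightening of the test that asserts shape rather than exact formatting (the aggregate's output column name varies across DataFusion versions, so the sketch avoids matching it):

    // Hypothetical follow-up assertions for test_func_sql: count(s.id) over
    // the 10-row fixture should come back as a single one-column row.
    let batch = &results[0];
    assert_eq!(batch.num_columns(), 1);
    assert_eq!(batch.num_rows(), 1);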

View File

@@ -113,14 +113,14 @@ impl ContextProvider for MetadataProvider {
     }
     fn udf_names(&self) -> Vec<String> {
-        todo!()
+        self.func_manager.udfs()
     }
     fn udaf_names(&self) -> Vec<String> {
-        todo!()
+        self.func_manager.udafs()
     }
     fn udwf_names(&self) -> Vec<String> {
-        todo!()
+        self.func_manager.udwfs()
     }
 }
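
Replacing the `todo!()` stubs matters because DataFusion's SQL planning consults `ContextProvider::udf_names` and friends (for example, to list candidates when a function cannot be resolved); a panicking stub would abort planning outright. A small sketch of probing availability through the trait; `has_scalar_fn` is an illustrative helper, not part of this crate:

    use datafusion::sql::planner::ContextProvider;

    // Illustrative check: names come back uppercased, matching how
    // SimpleFunctionMetadataManager keys its registrations.
    fn has_scalar_fn(provider: &impl ContextProvider, name: &str) -> bool {
        provider.udf_names().iter().any(|n| n.eq_ignore_ascii_case(name))
    }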