mirror of
https://github.com/rustfs/rustfs.git
synced 2026-01-17 09:40:32 +00:00
@@ -1,4 +1,3 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
|
||||
@@ -7,11 +6,11 @@ use crate::QueryResult;
|
||||
|
||||
pub type FuncMetaManagerRef = Arc<dyn FunctionMetadataManager + Send + Sync>;
|
||||
pub trait FunctionMetadataManager {
|
||||
fn register_udf(&mut self, udf: ScalarUDF) -> QueryResult<()>;
|
||||
fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> QueryResult<()>;
|
||||
|
||||
fn register_udaf(&mut self, udaf: AggregateUDF) -> QueryResult<()>;
|
||||
fn register_udaf(&mut self, udaf: Arc<AggregateUDF>) -> QueryResult<()>;
|
||||
|
||||
fn register_udwf(&mut self, udwf: WindowUDF) -> QueryResult<()>;
|
||||
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> QueryResult<()>;
|
||||
|
||||
fn udf(&self, name: &str) -> QueryResult<Arc<ScalarUDF>>;
|
||||
|
||||
@@ -19,5 +18,7 @@ pub trait FunctionMetadataManager {
|
||||
|
||||
fn udwf(&self, name: &str) -> QueryResult<Arc<WindowUDF>>;
|
||||
|
||||
fn udfs(&self) -> HashSet<String>;
|
||||
fn udfs(&self) -> Vec<String>;
|
||||
fn udafs(&self) -> Vec<String>;
|
||||
fn udwfs(&self) -> Vec<String>;
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ impl SimpleQueryDispatcher {
|
||||
let path = format!("s3://{}/{}", self.input.bucket, self.input.key);
|
||||
let table_path = ListingTableUrl::parse(path)?;
|
||||
let listing_options = if self.input.request.input_serialization.csv.is_some() {
|
||||
let file_format = CsvFormat::default().with_options(CsvOptions::default().with_has_header(false));
|
||||
let file_format = CsvFormat::default().with_options(CsvOptions::default().with_has_header(true));
|
||||
ListingOptions::new(Arc::new(file_format)).with_file_extension(".csv")
|
||||
} else if self.input.request.input_serialization.parquet.is_some() {
|
||||
let file_format = ParquetFormat::new();
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::query::function::FunctionMetadataManager;
|
||||
use api::{QueryError, QueryResult};
|
||||
use datafusion::execution::SessionStateDefaults;
|
||||
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
|
||||
use tracing::debug;
|
||||
|
||||
pub type SimpleFunctionMetadataManagerRef = Arc<SimpleFunctionMetadataManager>;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Debug)]
|
||||
pub struct SimpleFunctionMetadataManager {
|
||||
/// Scalar functions that are registered with the context
|
||||
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
|
||||
@@ -17,19 +19,53 @@ pub struct SimpleFunctionMetadataManager {
|
||||
pub window_functions: HashMap<String, Arc<WindowUDF>>,
|
||||
}
|
||||
|
||||
impl Default for SimpleFunctionMetadataManager {
|
||||
fn default() -> Self {
|
||||
let mut func_meta_manager = Self {
|
||||
scalar_functions: Default::default(),
|
||||
aggregate_functions: Default::default(),
|
||||
window_functions: Default::default(),
|
||||
};
|
||||
SessionStateDefaults::default_scalar_functions().into_iter().for_each(|udf| {
|
||||
let existing_udf = func_meta_manager.register_udf(udf.clone());
|
||||
if let Ok(()) = existing_udf {
|
||||
debug!("Overwrote an existing UDF: {}", udf.name());
|
||||
}
|
||||
});
|
||||
|
||||
SessionStateDefaults::default_aggregate_functions()
|
||||
.into_iter()
|
||||
.for_each(|udaf| {
|
||||
let existing_udaf = func_meta_manager.register_udaf(udaf.clone());
|
||||
if let Ok(()) = existing_udaf {
|
||||
debug!("Overwrote an existing UDAF: {}", udaf.name());
|
||||
}
|
||||
});
|
||||
|
||||
SessionStateDefaults::default_window_functions().into_iter().for_each(|udwf| {
|
||||
let existing_udwf = func_meta_manager.register_udwf(udwf.clone());
|
||||
if let Ok(()) = existing_udwf {
|
||||
debug!("Overwrote an existing UDWF: {}", udwf.name());
|
||||
}
|
||||
});
|
||||
|
||||
func_meta_manager
|
||||
}
|
||||
}
|
||||
|
||||
impl FunctionMetadataManager for SimpleFunctionMetadataManager {
|
||||
fn register_udf(&mut self, f: ScalarUDF) -> QueryResult<()> {
|
||||
self.scalar_functions.insert(f.inner().name().to_uppercase(), Arc::new(f));
|
||||
fn register_udf(&mut self, f: Arc<ScalarUDF>) -> QueryResult<()> {
|
||||
self.scalar_functions.insert(f.inner().name().to_uppercase(), f);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_udaf(&mut self, f: AggregateUDF) -> QueryResult<()> {
|
||||
self.aggregate_functions.insert(f.inner().name().to_uppercase(), Arc::new(f));
|
||||
fn register_udaf(&mut self, f: Arc<AggregateUDF>) -> QueryResult<()> {
|
||||
self.aggregate_functions.insert(f.inner().name().to_uppercase(), f);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_udwf(&mut self, f: WindowUDF) -> QueryResult<()> {
|
||||
self.window_functions.insert(f.inner().name().to_uppercase(), Arc::new(f));
|
||||
fn register_udwf(&mut self, f: Arc<WindowUDF>) -> QueryResult<()> {
|
||||
self.window_functions.insert(f.inner().name().to_uppercase(), f);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -57,7 +93,13 @@ impl FunctionMetadataManager for SimpleFunctionMetadataManager {
|
||||
.ok_or_else(|| QueryError::FunctionNotExists { name: name.to_string() })
|
||||
}
|
||||
|
||||
fn udfs(&self) -> HashSet<String> {
|
||||
fn udfs(&self) -> Vec<String> {
|
||||
self.scalar_functions.keys().cloned().collect()
|
||||
}
|
||||
fn udafs(&self) -> Vec<String> {
|
||||
self.aggregate_functions.keys().cloned().collect()
|
||||
}
|
||||
fn udwfs(&self) -> Vec<String> {
|
||||
self.window_functions.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,24 +141,58 @@ mod tests {
|
||||
let results = result.result().chunk_result().await.unwrap().to_vec();
|
||||
|
||||
let expected = [
|
||||
"+----------------+----------+----------+------------+----------+",
|
||||
"| column_1 | column_2 | column_3 | column_4 | column_5 |",
|
||||
"+----------------+----------+----------+------------+----------+",
|
||||
"| id | name | age | department | salary |",
|
||||
"| 1 | Alice | 25 | HR | 5000 |",
|
||||
"| 2 | Bob | 30 | IT | 6000 |",
|
||||
"| 3 | Charlie | 35 | Finance | 7000 |",
|
||||
"| 4 | Diana | 22 | Marketing | 4500 |",
|
||||
"| 5 | Eve | 28 | IT | 5500 |",
|
||||
"| 6 | Frank | 40 | Finance | 8000 |",
|
||||
"| 7 | Grace | 26 | HR | 5200 |",
|
||||
"| 8 | Henry | 32 | IT | 6200 |",
|
||||
"| 9 | Ivy | 24 | Marketing | 4800 |",
|
||||
"| 10 | Jack | 38 | Finance | 7500 |",
|
||||
"+----------------+----------+----------+------------+----------+",
|
||||
"+----------------+---------+-----+------------+--------+",
|
||||
"| id | name | age | department | salary |",
|
||||
"+----------------+---------+-----+------------+--------+",
|
||||
"| 1 | Alice | 25 | HR | 5000 |",
|
||||
"| 2 | Bob | 30 | IT | 6000 |",
|
||||
"| 3 | Charlie | 35 | Finance | 7000 |",
|
||||
"| 4 | Diana | 22 | Marketing | 4500 |",
|
||||
"| 5 | Eve | 28 | IT | 5500 |",
|
||||
"| 6 | Frank | 40 | Finance | 8000 |",
|
||||
"| 7 | Grace | 26 | HR | 5200 |",
|
||||
"| 8 | Henry | 32 | IT | 6200 |",
|
||||
"| 9 | Ivy | 24 | Marketing | 4800 |",
|
||||
"| 10 | Jack | 38 | Finance | 7500 |",
|
||||
"+----------------+---------+-----+------------+--------+",
|
||||
];
|
||||
|
||||
assert_batches_eq!(expected, &results);
|
||||
pretty::print_batches(&results).unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_func_sql() {
|
||||
let sql = "select count(s.id) from S3Object as s";
|
||||
let input = SelectObjectContentInput {
|
||||
bucket: "dandan".to_string(),
|
||||
expected_bucket_owner: None,
|
||||
key: "test.csv".to_string(),
|
||||
sse_customer_algorithm: None,
|
||||
sse_customer_key: None,
|
||||
sse_customer_key_md5: None,
|
||||
request: SelectObjectContentRequest {
|
||||
expression: sql.to_string(),
|
||||
expression_type: ExpressionType::from_static("SQL"),
|
||||
input_serialization: InputSerialization {
|
||||
csv: Some(CSVInput::default()),
|
||||
..Default::default()
|
||||
},
|
||||
output_serialization: OutputSerialization {
|
||||
csv: Some(CSVOutput::default()),
|
||||
..Default::default()
|
||||
},
|
||||
request_progress: None,
|
||||
scan_range: None,
|
||||
},
|
||||
};
|
||||
let db = make_rustfsms(input.clone(), true).await.unwrap();
|
||||
let query = Query::new(Context { input }, sql.to_string());
|
||||
|
||||
let result = db.execute(&query).await.unwrap();
|
||||
|
||||
let results = result.result().chunk_result().await.unwrap().to_vec();
|
||||
pretty::print_batches(&results).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,14 +113,14 @@ impl ContextProvider for MetadataProvider {
|
||||
}
|
||||
|
||||
fn udf_names(&self) -> Vec<String> {
|
||||
todo!()
|
||||
self.func_manager.udfs()
|
||||
}
|
||||
|
||||
fn udaf_names(&self) -> Vec<String> {
|
||||
todo!()
|
||||
self.func_manager.udafs()
|
||||
}
|
||||
|
||||
fn udwf_names(&self) -> Vec<String> {
|
||||
todo!()
|
||||
self.func_manager.udwfs()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user