Skip to content

Commit 0973140

Browse files
authored
refactor(core): self-manage the list of the core functions (#1346)
1 parent 52d0150 commit 0973140

File tree

11 files changed

+397
-103
lines changed

11 files changed

+397
-103
lines changed

wren-core-py/src/context.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use tokio::runtime::Runtime;
3232
use wren_core::array::AsArray;
3333
use wren_core::ast::{visit_statements_mut, Expr, Statement, Value, ValueWithSpan};
3434
use wren_core::dialect::GenericDialect;
35-
use wren_core::mdl::context::create_ctx_with_mdl;
35+
use wren_core::mdl::context::apply_wren_on_ctx;
3636
use wren_core::mdl::function::{
3737
ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType,
3838
RemoteFunction,
@@ -91,11 +91,11 @@ impl PySessionContext {
9191
.collect::<Vec<_>>();
9292

9393
let config = SessionConfig::default().with_information_schema(true);
94-
let ctx = wren_core::SessionContext::new_with_config(config);
94+
let ctx = wren_core::mdl::create_wren_ctx(Some(config));
9595
let runtime = Runtime::new().map_err(CoreError::from)?;
9696

9797
let registered_functions = runtime
98-
.block_on(Self::get_regietered_functions(&ctx))
98+
.block_on(Self::get_registered_functions(&ctx))
9999
.map(|functions| {
100100
functions
101101
.into_iter()
@@ -169,7 +169,7 @@ impl PySessionContext {
169169
Ok(analyzed_mdl) => {
170170
let analyzed_mdl = Arc::new(analyzed_mdl);
171171
let unparser_ctx = runtime
172-
.block_on(create_ctx_with_mdl(
172+
.block_on(apply_wren_on_ctx(
173173
&ctx,
174174
Arc::clone(&analyzed_mdl),
175175
Arc::clone(&properties_ref),
@@ -178,7 +178,7 @@ impl PySessionContext {
178178
.map_err(CoreError::from)?;
179179

180180
let exec_ctx = runtime
181-
.block_on(create_ctx_with_mdl(
181+
.block_on(apply_wren_on_ctx(
182182
&ctx,
183183
Arc::clone(&analyzed_mdl),
184184
Arc::clone(&properties_ref),
@@ -226,7 +226,7 @@ impl PySessionContext {
226226
pub fn get_available_functions(&self) -> PyResult<Vec<PyRemoteFunction>> {
227227
let registered_functions: Vec<PyRemoteFunction> = self
228228
.runtime
229-
.block_on(Self::get_regietered_functions(&self.exec_ctx))
229+
.block_on(Self::get_registered_functions(&self.exec_ctx))
230230
.map_err(CoreError::from)?
231231
.into_iter()
232232
.map(|f| f.into())
@@ -321,7 +321,7 @@ impl PySessionContext {
321321
/// The `name` is the name of the function.
322322
/// The `function_type` is the type of the function. (e.g. scalar, aggregate, window)
323323
/// The `description` is the description of the function.
324-
async fn get_regietered_functions(
324+
async fn get_registered_functions(
325325
ctx: &wren_core::SessionContext,
326326
) -> PyResult<Vec<RemoteFunctionDto>> {
327327
let sql = r#"

wren-core/core/src/mdl/context.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use parking_lot::RwLock;
4343
pub type SessionPropertiesRef = Arc<HashMap<String, Option<String>>>;
4444

4545
/// Apply Wren Rules to the context for sql generation.
46-
pub async fn create_ctx_with_mdl(
46+
pub async fn apply_wren_on_ctx(
4747
ctx: &SessionContext,
4848
analyzed_mdl: Arc<AnalyzedWrenMDL>,
4949
properties: SessionPropertiesRef,
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::{
4+
functions_aggregate::{
5+
approx_percentile_cont::approx_percentile_cont_udaf,
6+
approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf, *,
7+
},
8+
logical_expr::AggregateUDF,
9+
};
10+
11+
pub fn aggregate_functions() -> Vec<Arc<AggregateUDF>> {
12+
vec![
13+
array_agg::array_agg_udaf(),
14+
first_last::first_value_udaf(),
15+
first_last::last_value_udaf(),
16+
covariance::covar_samp_udaf(),
17+
covariance::covar_pop_udaf(),
18+
correlation::corr_udaf(),
19+
sum::sum_udaf(),
20+
min_max::max_udaf(),
21+
min_max::min_udaf(),
22+
median::median_udaf(),
23+
count::count_udaf(),
24+
regr::regr_slope_udaf(),
25+
regr::regr_intercept_udaf(),
26+
regr::regr_count_udaf(),
27+
regr::regr_r2_udaf(),
28+
regr::regr_avgx_udaf(),
29+
regr::regr_avgy_udaf(),
30+
regr::regr_sxx_udaf(),
31+
regr::regr_syy_udaf(),
32+
regr::regr_sxy_udaf(),
33+
variance::var_samp_udaf(),
34+
variance::var_pop_udaf(),
35+
stddev::stddev_udaf(),
36+
stddev::stddev_pop_udaf(),
37+
approx_median::approx_median_udaf(),
38+
approx_distinct::approx_distinct_udaf(),
39+
approx_percentile_cont_udaf(),
40+
approx_percentile_cont_with_weight_udaf(),
41+
string_agg::string_agg_udaf(),
42+
bit_and_or_xor::bit_and_udaf(),
43+
bit_and_or_xor::bit_or_udaf(),
44+
bit_and_or_xor::bit_xor_udaf(),
45+
bool_and_or::bool_and_udaf(),
46+
bool_and_or::bool_or_udaf(),
47+
average::avg_udaf(),
48+
grouping::grouping_udaf(),
49+
nth_value::nth_value_udaf(),
50+
]
51+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
mod aggregate;
2+
mod remote_function;
3+
mod scalar;
4+
mod table;
5+
mod window;
6+
pub use aggregate::aggregate_functions;
7+
pub use remote_function::*;
8+
pub use scalar::scalar_functions;
9+
pub use table::table_functions;
10+
pub use window::window_functions;
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::{
4+
functions::{
5+
core::*, crypto::*, datetime::*, encoding::*, math::*, regex::*, string::*,
6+
unicode::*,
7+
},
8+
functions_nested::*,
9+
logical_expr::ScalarUDF,
10+
};
11+
12+
pub fn scalar_functions() -> Vec<Arc<ScalarUDF>> {
13+
vec![
14+
// datefusion core
15+
nullif(),
16+
arrow_cast(),
17+
nvl(),
18+
nvl2(),
19+
overlay(),
20+
arrow_typeof(),
21+
named_struct(),
22+
get_field(),
23+
coalesce(),
24+
greatest(),
25+
least(),
26+
union_extract(),
27+
union_tag(),
28+
version(),
29+
r#struct(),
30+
// datafusion crypto
31+
digest(),
32+
md5(),
33+
sha224(),
34+
sha256(),
35+
sha384(),
36+
sha512(),
37+
// datafusion datetime
38+
current_date(),
39+
current_time(),
40+
date_bin(),
41+
date_part(),
42+
date_trunc(),
43+
date_diff(),
44+
from_unixtime(),
45+
make_date(),
46+
now(),
47+
to_char(),
48+
to_date(),
49+
to_local_time(),
50+
to_unixtime(),
51+
to_timestamp(),
52+
to_timestamp_seconds(),
53+
to_timestamp_millis(),
54+
to_timestamp_micros(),
55+
to_timestamp_nanos(),
56+
// datafusion encoding
57+
encode(),
58+
decode(),
59+
// datafusion math
60+
abs(),
61+
acos(),
62+
acosh(),
63+
asin(),
64+
asinh(),
65+
atan(),
66+
atan2(),
67+
atanh(),
68+
cbrt(),
69+
ceil(),
70+
cos(),
71+
cosh(),
72+
cot(),
73+
degrees(),
74+
exp(),
75+
factorial(),
76+
floor(),
77+
gcd(),
78+
isnan(),
79+
iszero(),
80+
lcm(),
81+
ln(),
82+
log(),
83+
log2(),
84+
log10(),
85+
nanvl(),
86+
pi(),
87+
power(),
88+
radians(),
89+
random(),
90+
signum(),
91+
sin(),
92+
sinh(),
93+
sqrt(),
94+
tan(),
95+
tanh(),
96+
round(),
97+
trunc(),
98+
// datafusion regex
99+
regexp_count(),
100+
regexp_match(),
101+
regexp_instr(),
102+
regexp_like(),
103+
regexp_replace(),
104+
// datafusion string
105+
ascii(),
106+
bit_length(),
107+
btrim(),
108+
chr(),
109+
concat(),
110+
concat_ws(),
111+
ends_with(),
112+
levenshtein(),
113+
lower(),
114+
ltrim(),
115+
octet_length(),
116+
repeat(),
117+
replace(),
118+
rtrim(),
119+
split_part(),
120+
starts_with(),
121+
to_hex(),
122+
upper(),
123+
uuid(),
124+
contains(),
125+
// datafusion unicode
126+
character_length(),
127+
find_in_set(),
128+
initcap(),
129+
left(),
130+
lpad(),
131+
reverse(),
132+
right(),
133+
rpad(),
134+
strpos(),
135+
substr(),
136+
substr_index(),
137+
translate(),
138+
// datafusion nested
139+
string::array_to_string_udf(),
140+
string::string_to_array_udf(),
141+
range::range_udf(),
142+
range::gen_series_udf(),
143+
dimension::array_dims_udf(),
144+
cardinality::cardinality_udf(),
145+
dimension::array_ndims_udf(),
146+
datafusion::functions_nested::concat::array_append_udf(),
147+
datafusion::functions_nested::concat::array_prepend_udf(),
148+
datafusion::functions_nested::concat::array_concat_udf(),
149+
except::array_except_udf(),
150+
extract::array_element_udf(),
151+
extract::array_pop_back_udf(),
152+
extract::array_pop_front_udf(),
153+
extract::array_slice_udf(),
154+
extract::array_any_value_udf(),
155+
make_array::make_array_udf(),
156+
array_has::array_has_udf(),
157+
array_has::array_has_all_udf(),
158+
array_has::array_has_any_udf(),
159+
empty::array_empty_udf(),
160+
length::array_length_udf(),
161+
distance::array_distance_udf(),
162+
flatten::flatten_udf(),
163+
min_max::array_max_udf(),
164+
min_max::array_min_udf(),
165+
sort::array_sort_udf(),
166+
datafusion::functions_nested::repeat::array_repeat_udf(),
167+
resize::array_resize_udf(),
168+
datafusion::functions_nested::reverse::array_reverse_udf(),
169+
set_ops::array_distinct_udf(),
170+
set_ops::array_intersect_udf(),
171+
set_ops::array_union_udf(),
172+
position::array_position_udf(),
173+
position::array_positions_udf(),
174+
remove::array_remove_udf(),
175+
remove::array_remove_all_udf(),
176+
remove::array_remove_n_udf(),
177+
datafusion::functions_nested::replace::array_replace_n_udf(),
178+
datafusion::functions_nested::replace::array_replace_all_udf(),
179+
datafusion::functions_nested::replace::array_replace_udf(),
180+
map::map_udf(),
181+
map_entries::map_entries_udf(),
182+
map_extract::map_extract_udf(),
183+
map_keys::map_keys_udf(),
184+
map_values::map_values_udf(),
185+
]
186+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::{
4+
catalog::TableFunction,
5+
functions_table::{generate_series, range},
6+
};
7+
8+
/// Returns all default table functions
9+
pub fn table_functions() -> Vec<Arc<TableFunction>> {
10+
vec![generate_series(), range()]
11+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::{functions_window::*, logical_expr::WindowUDF};
4+
5+
pub fn window_functions() -> Vec<Arc<WindowUDF>> {
6+
vec![
7+
cume_dist::cume_dist_udwf(),
8+
row_number::row_number_udwf(),
9+
lead_lag::lead_udwf(),
10+
lead_lag::lag_udwf(),
11+
rank::rank_udwf(),
12+
rank::dense_rank_udwf(),
13+
rank::percent_rank_udwf(),
14+
ntile::ntile_udwf(),
15+
nth_value::first_value_udwf(),
16+
nth_value::last_value_udwf(),
17+
nth_value::nth_value_udwf(),
18+
]
19+
}

0 commit comments

Comments
 (0)