Skip to content

Commit 8eedff8

Browse files
Merge branch 'main' into doc-move-user-defined-plan
2 parents 8a044eb + 930620a commit 8eedff8

File tree

19 files changed

+126
-51
lines changed

19 files changed

+126
-51
lines changed

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -954,13 +954,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
954954
}
955955

956956
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
957-
let input_expr = acc_args
958-
.exprs
959-
.first()
960-
.ok_or(exec_datafusion_err!("Expected one argument"))?;
961-
let input_field = input_expr.return_field(acc_args.schema)?;
962-
963-
let double_output = input_field
957+
let double_output = acc_args.expr_fields[0]
964958
.metadata()
965959
.get("modify_values")
966960
.map(|v| v == "double_output")

datafusion/ffi/src/udaf/accumulator_args.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
9797
pub struct ForeignAccumulatorArgs {
9898
pub return_field: FieldRef,
9999
pub schema: Schema,
100+
pub expr_fields: Vec<FieldRef>,
100101
pub ignore_nulls: bool,
101102
pub order_bys: Vec<PhysicalSortExpr>,
102103
pub is_reversed: bool,
@@ -132,9 +133,15 @@ impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
132133

133134
let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?;
134135

136+
let expr_fields = exprs
137+
.iter()
138+
.map(|e| e.return_field(&schema))
139+
.collect::<Result<Vec<_>, _>>()?;
140+
135141
Ok(Self {
136142
return_field,
137143
schema,
144+
expr_fields,
138145
ignore_nulls: proto_def.ignore_nulls,
139146
order_bys,
140147
is_reversed: value.is_reversed,
@@ -150,6 +157,7 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
150157
Self {
151158
return_field: Arc::clone(&value.return_field),
152159
schema: &value.schema,
160+
expr_fields: &value.expr_fields,
153161
ignore_nulls: value.ignore_nulls,
154162
order_bys: &value.order_bys,
155163
is_reversed: value.is_reversed,
@@ -175,6 +183,7 @@ mod tests {
175183
let orig_args = AccumulatorArgs {
176184
return_field: Field::new("f", DataType::Float64, true).into(),
177185
schema: &schema,
186+
expr_fields: &[Field::new("a", DataType::Int32, true).into()],
178187
ignore_nulls: false,
179188
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
180189
is_reversed: false,

datafusion/ffi/src/udaf/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ mod tests {
705705
let acc_args = AccumulatorArgs {
706706
return_field: Field::new("f", DataType::Float64, true).into(),
707707
schema: &schema,
708+
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
708709
ignore_nulls: true,
709710
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
710711
is_reversed: false,
@@ -782,6 +783,7 @@ mod tests {
782783
let acc_args = AccumulatorArgs {
783784
return_field: Field::new("f", DataType::Float64, true).into(),
784785
schema: &schema,
786+
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
785787
ignore_nulls: true,
786788
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
787789
is_reversed: false,

datafusion/functions-aggregate-common/src/accumulator.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ pub struct AccumulatorArgs<'a> {
3030
/// The return field of the aggregate function.
3131
pub return_field: FieldRef,
3232

33-
/// The schema of the input arguments
33+
/// Input schema to the aggregate function. If you need to check data type, nullability
34+
/// or metadata of input arguments then you should use `expr_fields` below instead.
3435
pub schema: &'a Schema,
3536

3637
/// Whether to ignore nulls.
@@ -67,6 +68,9 @@ pub struct AccumulatorArgs<'a> {
6768

6869
/// The physical expression of arguments the aggregate function takes.
6970
pub exprs: &'a [Arc<dyn PhysicalExpr>],
71+
72+
/// Fields corresponding to each expr (same order & length).
73+
pub expr_fields: &'a [FieldRef],
7074
}
7175

7276
impl AccumulatorArgs<'_> {

datafusion/functions-aggregate/benches/count.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,17 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};
3333

3434
fn prepare_group_accumulator() -> Box<dyn GroupsAccumulator> {
3535
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)]));
36+
let expr = col("f", &schema).unwrap();
3637
let accumulator_args = AccumulatorArgs {
3738
return_field: Field::new("f", DataType::Int64, true).into(),
3839
schema: &schema,
40+
expr_fields: &[expr.return_field(&schema).unwrap()],
3941
ignore_nulls: false,
4042
order_bys: &[],
4143
is_reversed: false,
4244
name: "COUNT(f)",
4345
is_distinct: false,
44-
exprs: &[col("f", &schema).unwrap()],
46+
exprs: &[expr],
4547
};
4648
let count_fn = Count::new();
4749

@@ -56,15 +58,17 @@ fn prepare_accumulator() -> Box<dyn Accumulator> {
5658
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
5759
true,
5860
)]));
61+
let expr = col("f", &schema).unwrap();
5962
let accumulator_args = AccumulatorArgs {
6063
return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
6164
schema: &schema,
65+
expr_fields: &[expr.return_field(&schema).unwrap()],
6266
ignore_nulls: false,
6367
order_bys: &[],
6468
is_reversed: false,
6569
name: "COUNT(f)",
6670
is_distinct: true,
67-
exprs: &[col("f", &schema).unwrap()],
71+
exprs: &[expr],
6872
};
6973
let count_fn = Count::new();
7074

datafusion/functions-aggregate/benches/min_max_bytes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ fn create_max_bytes_accumulator() -> Box<dyn GroupsAccumulator> {
4444
max.create_groups_accumulator(AccumulatorArgs {
4545
return_field: Arc::new(Field::new("value", DataType::Utf8, true)),
4646
schema: &input_schema,
47+
expr_fields: &[Field::new("value", DataType::Utf8, true).into()],
4748
ignore_nulls: true,
4849
order_bys: &[],
4950
is_reversed: false,

datafusion/functions-aggregate/benches/sum.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ fn prepare_accumulator(data_type: &DataType) -> Box<dyn GroupsAccumulator> {
3131
let field = Field::new("f", data_type.clone(), true).into();
3232
let schema = Arc::new(Schema::new(vec![Arc::clone(&field)]));
3333
let accumulator_args = AccumulatorArgs {
34-
return_field: field,
34+
return_field: Arc::clone(&field),
3535
schema: &schema,
36+
expr_fields: &[field],
3637
ignore_nulls: false,
3738
order_bys: &[],
3839
is_reversed: false,

datafusion/functions-aggregate/src/approx_distinct.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ impl AggregateUDFImpl for ApproxDistinct {
361361
}
362362

363363
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
364-
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
364+
let data_type = acc_args.expr_fields[0].data_type();
365365

366366
let accumulator: Box<dyn Accumulator> = match data_type {
367367
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL

datafusion/functions-aggregate/src/approx_median.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ impl AggregateUDFImpl for ApproxMedian {
134134

135135
Ok(Box::new(ApproxPercentileAccumulator::new(
136136
0.5_f64,
137-
acc_args.exprs[0].data_type(acc_args.schema)?,
137+
acc_args.expr_fields[0].data_type().clone(),
138138
)))
139139
}
140140

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ impl ApproxPercentileCont {
187187
None
188188
};
189189

190-
let data_type = args.exprs[0].data_type(args.schema)?;
190+
let data_type = args.expr_fields[0].data_type();
191191
let accumulator: ApproxPercentileAccumulator = match data_type {
192-
t @ (DataType::UInt8
192+
DataType::UInt8
193193
| DataType::UInt16
194194
| DataType::UInt32
195195
| DataType::UInt64
@@ -198,12 +198,11 @@ impl ApproxPercentileCont {
198198
| DataType::Int32
199199
| DataType::Int64
200200
| DataType::Float32
201-
| DataType::Float64) => {
201+
| DataType::Float64 => {
202202
if let Some(max_size) = tdigest_max_size {
203-
ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
204-
}else{
205-
ApproxPercentileAccumulator::new(percentile, t)
206-
203+
ApproxPercentileAccumulator::new_with_max_size(percentile, data_type.clone(), max_size)
204+
} else {
205+
ApproxPercentileAccumulator::new(percentile, data_type.clone())
207206
}
208207
}
209208
other => {

0 commit comments

Comments
 (0)