Skip to content

Commit 6d1ae75

Browse files
committed
Factor out Substrait consumers into separate files
1 parent 280997d commit 6d1ae75

36 files changed

+4272
-3452
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

-3,452
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use super::{from_substrait_func_args, substrait_fun_name, SubstraitConsumer};
19+
use datafusion::common::{not_impl_datafusion_err, plan_err, DFSchema, ScalarValue};
20+
use datafusion::execution::FunctionRegistry;
21+
use datafusion::logical_expr::{expr, Expr, SortExpr};
22+
use std::sync::Arc;
23+
use substrait::proto::AggregateFunction;
24+
25+
/// Convert Substrait AggregateFunction to DataFusion Expr
26+
pub async fn from_substrait_agg_func(
27+
consumer: &impl SubstraitConsumer,
28+
f: &AggregateFunction,
29+
input_schema: &DFSchema,
30+
filter: Option<Box<Expr>>,
31+
order_by: Option<Vec<SortExpr>>,
32+
distinct: bool,
33+
) -> datafusion::common::Result<Arc<Expr>> {
34+
let Some(fn_signature) = consumer
35+
.get_extensions()
36+
.functions
37+
.get(&f.function_reference)
38+
else {
39+
return plan_err!(
40+
"Aggregate function not registered: function anchor = {:?}",
41+
f.function_reference
42+
);
43+
};
44+
45+
let fn_name = substrait_fun_name(fn_signature);
46+
let udaf = consumer.get_function_registry().udaf(fn_name);
47+
let udaf = udaf.map_err(|_| {
48+
not_impl_datafusion_err!(
49+
"Aggregate function {} is not supported: function anchor = {:?}",
50+
fn_signature,
51+
f.function_reference
52+
)
53+
})?;
54+
55+
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
56+
57+
// Datafusion does not support aggregate functions with no arguments, so
58+
// we inject a dummy argument that does not affect the query, but allows
59+
// us to bypass this limitation.
60+
let args = if udaf.name() == "count" && args.is_empty() {
61+
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
62+
} else {
63+
args
64+
};
65+
66+
Ok(Arc::new(Expr::AggregateFunction(
67+
expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None),
68+
)))
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use super::grouping::from_substrait_grouping;
2+
use super::SubstraitConsumer;
3+
use super::{from_substrait_agg_func, from_substrait_sorts};
4+
use datafusion::common::not_impl_err;
5+
use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder};
6+
use substrait::proto::aggregate_function::AggregationInvocation;
7+
use substrait::proto::AggregateRel;
8+
9+
pub async fn from_aggregate_rel(
10+
consumer: &impl SubstraitConsumer,
11+
agg: &AggregateRel,
12+
) -> datafusion::common::Result<LogicalPlan> {
13+
if let Some(input) = agg.input.as_ref() {
14+
let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?);
15+
let mut ref_group_exprs = vec![];
16+
17+
for e in &agg.grouping_expressions {
18+
let x = consumer.consume_expression(e, input.schema()).await?;
19+
ref_group_exprs.push(x);
20+
}
21+
22+
let mut group_exprs = vec![];
23+
let mut aggr_exprs = vec![];
24+
25+
match agg.groupings.len() {
26+
1 => {
27+
group_exprs.extend_from_slice(
28+
&from_substrait_grouping(
29+
consumer,
30+
&agg.groupings[0],
31+
&ref_group_exprs,
32+
input.schema(),
33+
)
34+
.await?,
35+
);
36+
}
37+
_ => {
38+
let mut grouping_sets = vec![];
39+
for grouping in &agg.groupings {
40+
let grouping_set = from_substrait_grouping(
41+
consumer,
42+
grouping,
43+
&ref_group_exprs,
44+
input.schema(),
45+
)
46+
.await?;
47+
grouping_sets.push(grouping_set);
48+
}
49+
// Single-element grouping expression of type Expr::GroupingSet.
50+
// Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
51+
// parsed by the producer and consumer, since Substrait does not have a type dedicated
52+
// to ROLLUP. Only vector of Groupings (grouping sets) is available.
53+
group_exprs
54+
.push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)));
55+
}
56+
};
57+
58+
for m in &agg.measures {
59+
let filter = match &m.filter {
60+
Some(fil) => Some(Box::new(
61+
consumer.consume_expression(fil, input.schema()).await?,
62+
)),
63+
None => None,
64+
};
65+
let agg_func = match &m.measure {
66+
Some(f) => {
67+
let distinct = match f.invocation {
68+
_ if f.invocation == AggregationInvocation::Distinct as i32 => {
69+
true
70+
}
71+
_ if f.invocation == AggregationInvocation::All as i32 => false,
72+
_ => false,
73+
};
74+
let order_by = if !f.sorts.is_empty() {
75+
Some(
76+
from_substrait_sorts(consumer, &f.sorts, input.schema())
77+
.await?,
78+
)
79+
} else {
80+
None
81+
};
82+
83+
from_substrait_agg_func(
84+
consumer,
85+
f,
86+
input.schema(),
87+
filter,
88+
order_by,
89+
distinct,
90+
)
91+
.await
92+
}
93+
None => {
94+
not_impl_err!("Aggregate without aggregate function is not supported")
95+
}
96+
};
97+
aggr_exprs.push(agg_func?.as_ref().clone());
98+
}
99+
input.aggregate(group_exprs, aggr_exprs)?.build()
100+
} else {
101+
not_impl_err!("Aggregate without an input is not valid")
102+
}
103+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use datafusion::common::{plan_err, substrait_err, ScalarValue};
18+
use datafusion::logical_expr::WindowFrameBound;
19+
use substrait::proto::expression::{
20+
window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind,
21+
window_function::Bound,
22+
};
23+
24+
pub(super) fn from_substrait_bound(
25+
bound: &Option<Bound>,
26+
is_lower: bool,
27+
) -> datafusion::common::Result<WindowFrameBound> {
28+
match bound {
29+
Some(b) => match &b.kind {
30+
Some(k) => match k {
31+
BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => {
32+
Ok(WindowFrameBound::CurrentRow)
33+
}
34+
BoundKind::Preceding(SubstraitBound::Preceding { offset }) => {
35+
if *offset <= 0 {
36+
return plan_err!("Preceding bound must be positive");
37+
}
38+
Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
39+
*offset as u64,
40+
))))
41+
}
42+
BoundKind::Following(SubstraitBound::Following { offset }) => {
43+
if *offset <= 0 {
44+
return plan_err!("Following bound must be positive");
45+
}
46+
Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some(
47+
*offset as u64,
48+
))))
49+
}
50+
BoundKind::Unbounded(SubstraitBound::Unbounded {}) => {
51+
if is_lower {
52+
Ok(WindowFrameBound::Preceding(ScalarValue::Null))
53+
} else {
54+
Ok(WindowFrameBound::Following(ScalarValue::Null))
55+
}
56+
}
57+
},
58+
None => substrait_err!("WindowFunction missing Substrait Bound kind"),
59+
},
60+
None => {
61+
if is_lower {
62+
Ok(WindowFrameBound::Preceding(ScalarValue::Null))
63+
} else {
64+
Ok(WindowFrameBound::Following(ScalarValue::Null))
65+
}
66+
}
67+
}
68+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use super::r#type::from_substrait_type_without_names;
18+
use super::SubstraitConsumer;
19+
use datafusion::common::{substrait_err, DFSchema};
20+
use datafusion::logical_expr::{Cast, Expr, TryCast};
21+
use substrait::proto::expression as substrait_expression;
22+
use substrait::proto::expression::cast::FailureBehavior::ReturnNull;
23+
24+
pub async fn from_cast(
25+
consumer: &impl SubstraitConsumer,
26+
cast: &substrait_expression::Cast,
27+
input_schema: &DFSchema,
28+
) -> datafusion::common::Result<Expr> {
29+
match cast.r#type.as_ref() {
30+
Some(output_type) => {
31+
let input_expr = Box::new(
32+
consumer
33+
.consume_expression(
34+
cast.input.as_ref().unwrap().as_ref(),
35+
input_schema,
36+
)
37+
.await?,
38+
);
39+
let data_type = from_substrait_type_without_names(consumer, output_type)?;
40+
if cast.failure_behavior() == ReturnNull {
41+
Ok(Expr::TryCast(TryCast::new(input_expr, data_type)))
42+
} else {
43+
Ok(Expr::Cast(Cast::new(input_expr, data_type)))
44+
}
45+
}
46+
None => substrait_err!("Cast expression without output type is not allowed"),
47+
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use super::utils::requalify_sides_if_needed;
18+
use super::SubstraitConsumer;
19+
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
20+
use substrait::proto::CrossRel;
21+
22+
pub async fn from_cross_rel(
23+
consumer: &impl SubstraitConsumer,
24+
cross: &CrossRel,
25+
) -> datafusion::common::Result<LogicalPlan> {
26+
let left = LogicalPlanBuilder::from(
27+
consumer.consume_rel(cross.left.as_ref().unwrap()).await?,
28+
);
29+
let right = LogicalPlanBuilder::from(
30+
consumer.consume_rel(cross.right.as_ref().unwrap()).await?,
31+
);
32+
let (left, right) = requalify_sides_if_needed(left, right)?;
33+
left.cross_join(right.build()?)?.build()
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use super::from_substrait_field_reference;
18+
use super::SubstraitConsumer;
19+
use datafusion::common::{not_impl_err, substrait_err};
20+
use datafusion::logical_expr::{LogicalPlan, Partitioning, Repartition};
21+
use std::sync::Arc;
22+
use substrait::proto::exchange_rel::ExchangeKind;
23+
use substrait::proto::ExchangeRel;
24+
25+
pub async fn from_exchange_rel(
26+
consumer: &impl SubstraitConsumer,
27+
exchange: &ExchangeRel,
28+
) -> datafusion::common::Result<LogicalPlan> {
29+
let Some(input) = exchange.input.as_ref() else {
30+
return substrait_err!("Unexpected empty input in ExchangeRel");
31+
};
32+
let input = Arc::new(consumer.consume_rel(input).await?);
33+
34+
let Some(exchange_kind) = &exchange.exchange_kind else {
35+
return substrait_err!("Unexpected empty input in ExchangeRel");
36+
};
37+
38+
// ref: https://substrait.io/relations/physical_relations/#exchange-types
39+
let partitioning_scheme = match exchange_kind {
40+
ExchangeKind::ScatterByFields(scatter_fields) => {
41+
let mut partition_columns = vec![];
42+
let input_schema = input.schema();
43+
for field_ref in &scatter_fields.fields {
44+
let column = from_substrait_field_reference(field_ref, input_schema)?;
45+
partition_columns.push(column);
46+
}
47+
Partitioning::Hash(partition_columns, exchange.partition_count as usize)
48+
}
49+
ExchangeKind::RoundRobin(_) => {
50+
Partitioning::RoundRobinBatch(exchange.partition_count as usize)
51+
}
52+
ExchangeKind::SingleTarget(_)
53+
| ExchangeKind::MultiTarget(_)
54+
| ExchangeKind::Broadcast(_) => {
55+
return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}");
56+
}
57+
};
58+
Ok(LogicalPlan::Repartition(Repartition {
59+
input,
60+
partitioning_scheme,
61+
}))
62+
}

0 commit comments

Comments
 (0)