Skip to content

Commit fffc8be

Browse files
authored
feat: support the ergonomics of getting list slice with stride (#8946)
* support list stride
* add test
* fix fmt
* rename and extend ListRange to ListStride
* fix ci
* fix doctest
* fix conflict and keep ListRange
* clean up the code
* chore
* fix ci
1 parent 1097dc0 commit fffc8be

File tree

17 files changed

+282
-80
lines changed

17 files changed

+282
-80
lines changed

datafusion/core/src/physical_planner.rs

+7-2
Original file line number | Diff line number | Diff line change
@@ -209,10 +209,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
209209
let key = create_physical_name(key, false)?;
210210
format!("{expr}[{key}]")
211211
}
212-
GetFieldAccess::ListRange { start, stop } => {
212+
GetFieldAccess::ListRange {
213+
start,
214+
stop,
215+
stride,
216+
} => {
213217
let start = create_physical_name(start, false)?;
214218
let stop = create_physical_name(stop, false)?;
215-
format!("{expr}[{start}:{stop}]")
219+
let stride = create_physical_name(stride, false)?;
220+
format!("{expr}[{start}:{stop}:{stride}]")
216221
}
217222
};
218223

datafusion/expr/src/expr.rs

+21-7
Original file line number | Diff line number | Diff line change
@@ -421,8 +421,12 @@ pub enum GetFieldAccess {
421421
NamedStructField { name: ScalarValue },
422422
/// Single list index, for example: `list[i]`
423423
ListIndex { key: Box<Expr> },
424-
/// List range, for example `list[i:j]`
425-
ListRange { start: Box<Expr>, stop: Box<Expr> },
424+
/// List stride, for example `list[i:j:k]`
425+
ListRange {
426+
start: Box<Expr>,
427+
stop: Box<Expr>,
428+
stride: Box<Expr>,
429+
},
426430
}
427431

428432
/// Returns the field of a [`arrow::array::ListArray`] or
@@ -1209,14 +1213,15 @@ impl Expr {
12091213
/// # use datafusion_expr::{lit, col};
12101214
/// let expr = col("c1")
12111215
/// .range(lit(2), lit(4));
1212-
/// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4)]");
1216+
/// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4):Int64(1)]");
12131217
/// ```
12141218
pub fn range(self, start: Expr, stop: Expr) -> Self {
12151219
Expr::GetIndexedField(GetIndexedField {
12161220
expr: Box::new(self),
12171221
field: GetFieldAccess::ListRange {
12181222
start: Box::new(start),
12191223
stop: Box::new(stop),
1224+
stride: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
12201225
},
12211226
})
12221227
}
@@ -1530,8 +1535,12 @@ impl fmt::Display for Expr {
15301535
write!(f, "({expr})[{name}]")
15311536
}
15321537
GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"),
1533-
GetFieldAccess::ListRange { start, stop } => {
1534-
write!(f, "({expr})[{start}:{stop}]")
1538+
GetFieldAccess::ListRange {
1539+
start,
1540+
stop,
1541+
stride,
1542+
} => {
1543+
write!(f, "({expr})[{start}:{stop}:{stride}]")
15351544
}
15361545
},
15371546
Expr::GroupingSet(grouping_sets) => match grouping_sets {
@@ -1732,10 +1741,15 @@ fn create_name(e: &Expr) -> Result<String> {
17321741
let key = create_name(key)?;
17331742
Ok(format!("{expr}[{key}]"))
17341743
}
1735-
GetFieldAccess::ListRange { start, stop } => {
1744+
GetFieldAccess::ListRange {
1745+
start,
1746+
stop,
1747+
stride,
1748+
} => {
17361749
let start = create_name(start)?;
17371750
let stop = create_name(stop)?;
1738-
Ok(format!("{expr}[{start}:{stop}]"))
1751+
let stride = create_name(stride)?;
1752+
Ok(format!("{expr}[{start}:{stop}:{stride}]"))
17391753
}
17401754
}
17411755
}

datafusion/expr/src/expr_schema.rs

+6-1
Original file line number | Diff line number | Diff line change
@@ -374,9 +374,14 @@ fn field_for_index<S: ExprSchema>(
374374
GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex {
375375
key_dt: key.get_type(schema)?,
376376
},
377-
GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange {
377+
GetFieldAccess::ListRange {
378+
start,
379+
stop,
380+
stride,
381+
} => GetFieldAccessSchema::ListRange {
378382
start_dt: start.get_type(schema)?,
379383
stop_dt: stop.get_type(schema)?,
384+
stride_dt: stride.get_type(schema)?,
380385
},
381386
}
382387
.get_accessed_field(&expr_dt)

datafusion/expr/src/field_util.rs

+7-6
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,11 @@ pub enum GetFieldAccessSchema {
2828
NamedStructField { name: ScalarValue },
2929
/// Single list index, for example: `list[i]`
3030
ListIndex { key_dt: DataType },
31-
/// List range, for example `list[i:j]`
31+
/// List stride, for example `list[i:j:k]`
3232
ListRange {
3333
start_dt: DataType,
3434
stop_dt: DataType,
35+
stride_dt: DataType,
3536
},
3637
}
3738

@@ -85,13 +86,13 @@ impl GetFieldAccessSchema {
8586
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
8687
}
8788
}
88-
Self::ListRange{ start_dt, stop_dt } => {
89-
match (data_type, start_dt, stop_dt) {
90-
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
91-
(DataType::List(_), _, _) => plan_err!(
89+
Self::ListRange { start_dt, stop_dt, stride_dt } => {
90+
match (data_type, start_dt, stop_dt, stride_dt) {
91+
(DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
92+
(DataType::List(_), _, _, _) => plan_err!(
9293
"Only ints are valid as an indexed field in a list"
9394
),
94-
(other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
95+
(other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
9596
}
9697
}
9798
}

datafusion/expr/src/tree_node/expr.rs

+2-2
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,8 @@ impl TreeNode for Expr {
5252
let expr = expr.as_ref();
5353
match field {
5454
GetFieldAccess::ListIndex {key} => vec![key.as_ref(), expr],
55-
GetFieldAccess::ListRange {start, stop} => {
56-
vec![start.as_ref(), stop.as_ref(), expr]
55+
GetFieldAccess::ListRange {start, stop, stride} => {
56+
vec![start.as_ref(), stop.as_ref(),stride.as_ref(), expr]
5757
}
5858
GetFieldAccess::NamedStructField { .. } => vec![expr],
5959
}

datafusion/physical-expr/src/expressions/get_indexed_field.rs

+64-21
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ use crate::PhysicalExpr;
2121
use datafusion_common::exec_err;
2222

2323
use crate::array_expressions::{array_element, array_slice};
24+
use crate::expressions::Literal;
2425
use crate::physical_expr::down_cast_any_ref;
2526
use arrow::{
2627
array::{Array, Scalar, StringArray},
@@ -43,10 +44,11 @@ pub enum GetFieldAccessExpr {
4344
NamedStructField { name: ScalarValue },
4445
/// Single list index, for example: `list[i]`
4546
ListIndex { key: Arc<dyn PhysicalExpr> },
46-
/// List range, for example `list[i:j]`
47+
/// List stride, for example `list[i:j:k]`
4748
ListRange {
4849
start: Arc<dyn PhysicalExpr>,
4950
stop: Arc<dyn PhysicalExpr>,
51+
stride: Arc<dyn PhysicalExpr>,
5052
},
5153
}
5254

@@ -55,8 +57,12 @@ impl std::fmt::Display for GetFieldAccessExpr {
5557
match self {
5658
GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name),
5759
GetFieldAccessExpr::ListIndex { key } => write!(f, "[{}]", key),
58-
GetFieldAccessExpr::ListRange { start, stop } => {
59-
write!(f, "[{}:{}]", start, stop)
60+
GetFieldAccessExpr::ListRange {
61+
start,
62+
stop,
63+
stride,
64+
} => {
65+
write!(f, "[{}:{}:{}]", start, stop, stride)
6066
}
6167
}
6268
}
@@ -76,12 +82,18 @@ impl PartialEq<dyn Any> for GetFieldAccessExpr {
7682
ListRange {
7783
start: start_lhs,
7884
stop: stop_lhs,
85+
stride: stride_lhs,
7986
},
8087
ListRange {
8188
start: start_rhs,
8289
stop: stop_rhs,
90+
stride: stride_rhs,
8391
},
84-
) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs),
92+
) => {
93+
start_lhs.eq(start_rhs)
94+
&& stop_lhs.eq(stop_rhs)
95+
&& stride_lhs.eq(stride_rhs)
96+
}
8597
(NamedStructField { .. }, ListIndex { .. } | ListRange { .. }) => false,
8698
(ListIndex { .. }, NamedStructField { .. } | ListRange { .. }) => false,
8799
(ListRange { .. }, NamedStructField { .. } | ListIndex { .. }) => false,
@@ -126,7 +138,32 @@ impl GetIndexedFieldExpr {
126138
start: Arc<dyn PhysicalExpr>,
127139
stop: Arc<dyn PhysicalExpr>,
128140
) -> Self {
129-
Self::new(arg, GetFieldAccessExpr::ListRange { start, stop })
141+
Self::new(
142+
arg,
143+
GetFieldAccessExpr::ListRange {
144+
start,
145+
stop,
146+
stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1))))
147+
as Arc<dyn PhysicalExpr>,
148+
},
149+
)
150+
}
151+
152+
/// Create a new [`GetIndexedFieldExpr`] for accessing the stride
153+
pub fn new_stride(
154+
arg: Arc<dyn PhysicalExpr>,
155+
start: Arc<dyn PhysicalExpr>,
156+
stop: Arc<dyn PhysicalExpr>,
157+
stride: Arc<dyn PhysicalExpr>,
158+
) -> Self {
159+
Self::new(
160+
arg,
161+
GetFieldAccessExpr::ListRange {
162+
start,
163+
stop,
164+
stride,
165+
},
166+
)
130167
}
131168

132169
/// Get the description of what field should be accessed
@@ -147,12 +184,15 @@ impl GetIndexedFieldExpr {
147184
GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex {
148185
key_dt: key.data_type(input_schema)?,
149186
},
150-
GetFieldAccessExpr::ListRange { start, stop } => {
151-
GetFieldAccessSchema::ListRange {
152-
start_dt: start.data_type(input_schema)?,
153-
stop_dt: stop.data_type(input_schema)?,
154-
}
155-
}
187+
GetFieldAccessExpr::ListRange {
188+
start,
189+
stop,
190+
stride,
191+
} => GetFieldAccessSchema::ListRange {
192+
start_dt: start.data_type(input_schema)?,
193+
stop_dt: stop.data_type(input_schema)?,
194+
stride_dt: stride.data_type(input_schema)?,
195+
},
156196
})
157197
}
158198
}
@@ -223,21 +263,24 @@ impl PhysicalExpr for GetIndexedFieldExpr {
223263
with utf8 indexes. Tried {dt:?} with {key:?} index"),
224264
}
225265
},
226-
GetFieldAccessExpr::ListRange{start, stop} => {
266+
GetFieldAccessExpr::ListRange { start, stop, stride } => {
227267
let start = start.evaluate(batch)?.into_array(batch.num_rows())?;
228268
let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?;
229-
match (array.data_type(), start.data_type(), stop.data_type()) {
230-
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[
231-
array, start, stop
232-
])?)),
233-
(DataType::List(_), start, stop) => exec_err!(
269+
let stride = stride.evaluate(batch)?.into_array(batch.num_rows())?;
270+
match (array.data_type(), start.data_type(), stop.data_type(), stride.data_type()) {
271+
(DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => {
272+
Ok(ColumnarValue::Array((array_slice(&[
273+
array, start, stop, stride
274+
]))?))
275+
},
276+
(DataType::List(_), start, stop, stride) => exec_err!(
234277
"get indexed field is only possible on lists with int64 indexes. \
235-
Tried with {start:?} and {stop:?} indices"),
236-
(dt, start, stop) => exec_err!(
278+
Tried with {start:?}, {stop:?} and {stride:?} indices"),
279+
(dt, start, stop, stride) => exec_err!(
237280
"get indexed field is only possible on lists with int64 indexes or struct \
238-
with utf8 indexes. Tried {dt:?} with {start:?} and {stop:?} indices"),
281+
with utf8 indexes. Tried {dt:?} with {start:?}, {stop:?} and {stride:?} indices"),
239282
}
240-
},
283+
}
241284
}
242285
}
243286

datafusion/physical-expr/src/planner.rs

+13-14
Original file line number | Diff line number | Diff line change
@@ -238,20 +238,19 @@ pub fn create_physical_expr(
238238
GetFieldAccess::ListIndex { key } => GetFieldAccessExpr::ListIndex {
239239
key: create_physical_expr(key, input_dfschema, execution_props)?,
240240
},
241-
GetFieldAccess::ListRange { start, stop } => {
242-
GetFieldAccessExpr::ListRange {
243-
start: create_physical_expr(
244-
start,
245-
input_dfschema,
246-
execution_props,
247-
)?,
248-
stop: create_physical_expr(
249-
stop,
250-
input_dfschema,
251-
execution_props,
252-
)?,
253-
}
254-
}
241+
GetFieldAccess::ListRange {
242+
start,
243+
stop,
244+
stride,
245+
} => GetFieldAccessExpr::ListRange {
246+
start: create_physical_expr(start, input_dfschema, execution_props)?,
247+
stop: create_physical_expr(stop, input_dfschema, execution_props)?,
248+
stride: create_physical_expr(
249+
stride,
250+
input_dfschema,
251+
execution_props,
252+
)?,
253+
},
255254
};
256255
Ok(Arc::new(GetIndexedFieldExpr::new(
257256
create_physical_expr(expr, input_dfschema, execution_props)?,

datafusion/proto/proto/datafusion.proto

+2
Original file line number | Diff line number | Diff line change
@@ -466,6 +466,7 @@ message ListIndex {
466466
message ListRange {
467467
LogicalExprNode start = 1;
468468
LogicalExprNode stop = 2;
469+
LogicalExprNode stride = 3;
469470
}
470471

471472
message GetIndexedField {
@@ -1773,6 +1774,7 @@ message ListIndexExpr {
17731774
message ListRangeExpr {
17741775
PhysicalExprNode start = 1;
17751776
PhysicalExprNode stop = 2;
1777+
PhysicalExprNode stride = 3;
17761778
}
17771779

17781780
message PhysicalGetIndexedFieldExprNode {

0 commit comments

Comments
 (0)