@@ -21,6 +21,7 @@ use crate::PhysicalExpr;
21
21
use datafusion_common:: exec_err;
22
22
23
23
use crate :: array_expressions:: { array_element, array_slice} ;
24
+ use crate :: expressions:: Literal ;
24
25
use crate :: physical_expr:: down_cast_any_ref;
25
26
use arrow:: {
26
27
array:: { Array , Scalar , StringArray } ,
@@ -43,10 +44,11 @@ pub enum GetFieldAccessExpr {
43
44
NamedStructField { name : ScalarValue } ,
44
45
/// Single list index, for example: `list[i]`
45
46
ListIndex { key : Arc < dyn PhysicalExpr > } ,
46
- /// List range , for example `list[i:j]`
47
+ /// List stride , for example `list[i:j:k ]`
47
48
ListRange {
48
49
start : Arc < dyn PhysicalExpr > ,
49
50
stop : Arc < dyn PhysicalExpr > ,
51
+ stride : Arc < dyn PhysicalExpr > ,
50
52
} ,
51
53
}
52
54
@@ -55,8 +57,12 @@ impl std::fmt::Display for GetFieldAccessExpr {
55
57
match self {
56
58
GetFieldAccessExpr :: NamedStructField { name } => write ! ( f, "[{}]" , name) ,
57
59
GetFieldAccessExpr :: ListIndex { key } => write ! ( f, "[{}]" , key) ,
58
- GetFieldAccessExpr :: ListRange { start, stop } => {
59
- write ! ( f, "[{}:{}]" , start, stop)
60
+ GetFieldAccessExpr :: ListRange {
61
+ start,
62
+ stop,
63
+ stride,
64
+ } => {
65
+ write ! ( f, "[{}:{}:{}]" , start, stop, stride)
60
66
}
61
67
}
62
68
}
@@ -76,12 +82,18 @@ impl PartialEq<dyn Any> for GetFieldAccessExpr {
76
82
ListRange {
77
83
start : start_lhs,
78
84
stop : stop_lhs,
85
+ stride : stride_lhs,
79
86
} ,
80
87
ListRange {
81
88
start : start_rhs,
82
89
stop : stop_rhs,
90
+ stride : stride_rhs,
83
91
} ,
84
- ) => start_lhs. eq ( start_rhs) && stop_lhs. eq ( stop_rhs) ,
92
+ ) => {
93
+ start_lhs. eq ( start_rhs)
94
+ && stop_lhs. eq ( stop_rhs)
95
+ && stride_lhs. eq ( stride_rhs)
96
+ }
85
97
( NamedStructField { .. } , ListIndex { .. } | ListRange { .. } ) => false ,
86
98
( ListIndex { .. } , NamedStructField { .. } | ListRange { .. } ) => false ,
87
99
( ListRange { .. } , NamedStructField { .. } | ListIndex { .. } ) => false ,
@@ -126,7 +138,32 @@ impl GetIndexedFieldExpr {
126
138
start : Arc < dyn PhysicalExpr > ,
127
139
stop : Arc < dyn PhysicalExpr > ,
128
140
) -> Self {
129
- Self :: new ( arg, GetFieldAccessExpr :: ListRange { start, stop } )
141
+ Self :: new (
142
+ arg,
143
+ GetFieldAccessExpr :: ListRange {
144
+ start,
145
+ stop,
146
+ stride : Arc :: new ( Literal :: new ( ScalarValue :: Int64 ( Some ( 1 ) ) ) )
147
+ as Arc < dyn PhysicalExpr > ,
148
+ } ,
149
+ )
150
+ }
151
+
152
+ /// Create a new [`GetIndexedFieldExpr`] for accessing the stride
153
+ pub fn new_stride (
154
+ arg : Arc < dyn PhysicalExpr > ,
155
+ start : Arc < dyn PhysicalExpr > ,
156
+ stop : Arc < dyn PhysicalExpr > ,
157
+ stride : Arc < dyn PhysicalExpr > ,
158
+ ) -> Self {
159
+ Self :: new (
160
+ arg,
161
+ GetFieldAccessExpr :: ListRange {
162
+ start,
163
+ stop,
164
+ stride,
165
+ } ,
166
+ )
130
167
}
131
168
132
169
/// Get the description of what field should be accessed
@@ -147,12 +184,15 @@ impl GetIndexedFieldExpr {
147
184
GetFieldAccessExpr :: ListIndex { key } => GetFieldAccessSchema :: ListIndex {
148
185
key_dt : key. data_type ( input_schema) ?,
149
186
} ,
150
- GetFieldAccessExpr :: ListRange { start, stop } => {
151
- GetFieldAccessSchema :: ListRange {
152
- start_dt : start. data_type ( input_schema) ?,
153
- stop_dt : stop. data_type ( input_schema) ?,
154
- }
155
- }
187
+ GetFieldAccessExpr :: ListRange {
188
+ start,
189
+ stop,
190
+ stride,
191
+ } => GetFieldAccessSchema :: ListRange {
192
+ start_dt : start. data_type ( input_schema) ?,
193
+ stop_dt : stop. data_type ( input_schema) ?,
194
+ stride_dt : stride. data_type ( input_schema) ?,
195
+ } ,
156
196
} )
157
197
}
158
198
}
@@ -223,21 +263,24 @@ impl PhysicalExpr for GetIndexedFieldExpr {
223
263
with utf8 indexes. Tried {dt:?} with {key:?} index") ,
224
264
}
225
265
} ,
226
- GetFieldAccessExpr :: ListRange { start, stop} => {
266
+ GetFieldAccessExpr :: ListRange { start, stop, stride } => {
227
267
let start = start. evaluate ( batch) ?. into_array ( batch. num_rows ( ) ) ?;
228
268
let stop = stop. evaluate ( batch) ?. into_array ( batch. num_rows ( ) ) ?;
229
- match ( array. data_type ( ) , start. data_type ( ) , stop. data_type ( ) ) {
230
- ( DataType :: List ( _) , DataType :: Int64 , DataType :: Int64 ) => Ok ( ColumnarValue :: Array ( array_slice ( & [
231
- array, start, stop
232
- ] ) ?) ) ,
233
- ( DataType :: List ( _) , start, stop) => exec_err ! (
269
+ let stride = stride. evaluate ( batch) ?. into_array ( batch. num_rows ( ) ) ?;
270
+ match ( array. data_type ( ) , start. data_type ( ) , stop. data_type ( ) , stride. data_type ( ) ) {
271
+ ( DataType :: List ( _) , DataType :: Int64 , DataType :: Int64 , DataType :: Int64 ) => {
272
+ Ok ( ColumnarValue :: Array ( ( array_slice ( & [
273
+ array, start, stop, stride
274
+ ] ) ) ?) )
275
+ } ,
276
+ ( DataType :: List ( _) , start, stop, stride) => exec_err ! (
234
277
"get indexed field is only possible on lists with int64 indexes. \
235
- Tried with {start:?} and {stop :?} indices") ,
236
- ( dt, start, stop) => exec_err ! (
278
+ Tried with {start:?}, {stop:?} and {stride :?} indices") ,
279
+ ( dt, start, stop, stride ) => exec_err ! (
237
280
"get indexed field is only possible on lists with int64 indexes or struct \
238
- with utf8 indexes. Tried {dt:?} with {start:?} and {stop :?} indices") ,
281
+ with utf8 indexes. Tried {dt:?} with {start:?}, {stop:?} and {stride :?} indices") ,
239
282
}
240
- } ,
283
+ }
241
284
}
242
285
}
243
286
0 commit comments