Commit 7f665e5

Returns min/max results from Delta Stats
Signed-off-by: Felipe Fujiy Pessoto <[email protected]>
1 parent 4074d29 commit 7f665e5
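
For context: this change extends the metadata-only query optimization so that MIN and MAX aggregates, not just COUNT(*), can be answered directly from the statistics in the Delta transaction log. A hypothetical sketch of a query shape that qualifies after this commit (the table name `events` and column `id` are invented for illustration):

```scala
// Hypothetical setup; assumes Spark with the Delta Lake extensions on the classpath.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("metadata-only-aggregates")
  .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
  .config("spark.sql.catalog.spark_catalog",
    "org.apache.spark.sql.delta.catalog.DeltaCatalog")
  .getOrCreate()

// With complete file-level stats, this aggregate can be served from the Delta log
// without scanning any data files: COUNT(*) from numRecords (minus deletion-vector
// cardinality), MIN/MAX from minValues/maxValues.
spark.sql("SELECT COUNT(*), MIN(id), MAX(id) FROM events").show()
```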

File tree: 2 files changed, +842 -72 lines changed


core/src/main/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuery.scala (+279 -31)
@@ -17,53 +17,301 @@
 package org.apache.spark.sql.delta.perf
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.delta.DeltaTable
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaTable, Snapshot}
 import org.apache.spark.sql.delta.files.TahoeLogFileIndex
 import org.apache.spark.sql.delta.stats.DeltaScanGenerator
-import org.apache.spark.sql.functions.{coalesce, col, count, lit, sum, when}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+import scala.collection.immutable.HashSet
 
 trait OptimizeMetadataOnlyDeltaQuery {
   def optimizeQueryWithMetadata(plan: LogicalPlan): LogicalPlan = {
     plan.transformUpWithSubqueries {
-      case agg@CountStarDeltaTable(countValue) =>
-        LocalRelation(agg.output, Seq(InternalRow(countValue)))
+      case agg@AggregateDeltaTable(tahoeLogFileIndex) =>
+        createLocalRelationPlan(agg, tahoeLogFileIndex)
     }
   }
 
   protected def getDeltaScanGenerator(index: TahoeLogFileIndex): DeltaScanGenerator
 
-  object CountStarDeltaTable {
-    def unapply(plan: Aggregate): Option[Long] = plan match {
-      case Aggregate(
-        Nil,
-        Seq(Alias(AggregateExpression(Count(Seq(Literal(1, _))), Complete, false, None, _), _)),
-        PhysicalOperation(_, Nil, DeltaTable(i: TahoeLogFileIndex))) if i.partitionFilters.isEmpty
-        => extractGlobalCount(i)
+  protected def createLocalRelationPlan(
+      plan: Aggregate,
+      tahoeLogFileIndex: TahoeLogFileIndex): LogicalPlan = {
+    val rowCount = extractGlobalCount(tahoeLogFileIndex)
+
+    if (rowCount.isDefined) {
+      lazy val columnStats = extractGlobalColumnStats(tahoeLogFileIndex)
+
+      val aggregatedValues = plan.aggregateExpressions.collect {
+        case Alias(AggregateExpression(
+          Count(Seq(Literal(1, _))), Complete, false, None, _), _) =>
+          rowCount.get
+        case Alias(AggregateExpression(
+          Min(minReference: AttributeReference), Complete, false, None, _), _)
+          if columnStats.contains(minReference.name) &&
+            // Avoid struct columns, they are not supported by this optimization.
+            // Sanity check only: minReference would be a GetStructField if it were a struct column.
+            minReference.references.size == 1 &&
+            !minReference.references.head.dataType.isInstanceOf[StructType] =>
+          val value = if (minReference.dataType == DateType
+            && columnStats(minReference.name).min != null) {
+            DateTimeUtils.fromJavaDate(
+              columnStats(minReference.name).min.asInstanceOf[java.sql.Date])
+          } else {
+            columnStats(minReference.name).min
+          }
+          value
+        case Alias(AggregateExpression(
+          Max(maxReference: AttributeReference), Complete, false, None, _), _)
+          if columnStats.contains(maxReference.name) &&
+            // Avoid struct columns, they are not supported by this optimization.
+            // Sanity check only: maxReference would be a GetStructField if it were a struct column.
+            maxReference.references.size == 1 &&
+            !maxReference.references.head.dataType.isInstanceOf[StructType] =>
+          val value = if (maxReference.dataType == DateType
+            && columnStats(maxReference.name).max != null) {
+            DateTimeUtils.fromJavaDate(
+              columnStats(maxReference.name).max.asInstanceOf[java.sql.Date])
+          } else {
+            columnStats(maxReference.name).max
+          }
+          value
+      }
+
+      if (plan.aggregateExpressions.size == aggregatedValues.size) {
+        val r = LocalRelation(
+          plan.output,
+          Seq(InternalRow.fromSeq(aggregatedValues)))
+        r
+      } else {
+        plan
+      }
+    } else {
+      plan
+    }
+  }
+
+  object AggregateDeltaTable {
+    def unapply(plan: Aggregate): Option[TahoeLogFileIndex] = plan match {
+      case Aggregate(Nil,
+        seqTest: Seq[Alias],
+        PhysicalOperation(projectList, Nil, DeltaTable(i: TahoeLogFileIndex)))
+        if i.partitionFilters.isEmpty
+        && projectList.forall {
+          case _: AttributeReference => true
+          // Disable the optimization if the Project is renaming the column,
+          // to avoid getting the incorrect column from stats, for example:
+          // SELECT MAX(Column2) FROM (SELECT Column1 AS Column2 FROM TableName)
+          // We could create a mapping (alias -> actual name) to avoid the problem
+          case a@Alias(_, _) => a.child.references.size == 1 &&
+            a.name.equals(a.child.references.head.name)
+          case _ => false
+        }
+        && seqTest.forall {
+          case Alias(AggregateExpression(
+            Count(Seq(Literal(1, _))) | Min(_) | Max(_), Complete, false, None, _), _) => true
+          case _ => false
+        } =>
+        Some(i)
+      // When all columns are selected, there is no Project/PhysicalOperation
+      case Aggregate(Nil,
+        seqTest: Seq[Alias],
+        DeltaTable(i: TahoeLogFileIndex))
+        if i.partitionFilters.isEmpty
+        && seqTest.forall {
+          case Alias(AggregateExpression(
+            Count(Seq(Literal(1, _))) | Min(_) | Max(_), Complete, false, None, _), _) => true
+          case _ => false
+        } =>
+        Some(i)
       case _ => None
     }
+  }
+
+  /** Return the number of rows in the table or `None` if we cannot calculate it from stats */
+  private def extractGlobalCount(tahoeLogFileIndex: TahoeLogFileIndex): Option[Long] = {
+    // account for deleted rows according to deletion vectors
+    val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0))
+    val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords")
+    val row = getDeltaScanGenerator(tahoeLogFileIndex).filesWithStatsForScan(Nil)
+      .agg(
+        sum(numLogicalRecords),
+        // Calculate the number of files missing `numRecords`
+        count(when(col("stats.numRecords").isNull, 1)))
+      .first
+
+    // The count agg is never null. A non-zero value means we have incomplete stats; otherwise,
+    // the sum agg is either null (for an empty table) or gives an accurate record count.
+    if (row.getLong(1) > 0) return None
+    val numRecords = if (row.isNullAt(0)) 0 else row.getLong(0)
+    Some(numRecords)
+  }
+
+  val columnStatsSupportedDataTypes: HashSet[DataType] = HashSet(
+    ByteType,
+    ShortType,
+    IntegerType,
+    LongType,
+    FloatType,
+    DoubleType,
+    DateType)
+
+  case class DeltaColumnStat(
+    min: Any,
+    max: Any,
+    nullCount: Option[Long],
+    distinctCount: Option[Long])
+
+  def extractGlobalColumnStats(tahoeLogFileIndex: TahoeLogFileIndex):
+    CaseInsensitiveMap[DeltaColumnStat] = {
+
+    // TODO Update this to work with DV (https://github.com/delta-io/delta/issues/1485)
+
+    val deltaScanGenerator = getDeltaScanGenerator(tahoeLogFileIndex)
+    val snapshot = deltaScanGenerator.snapshotToScan
+
+    def extractGlobalColumnStatsDeltaLog(snapshot: Snapshot):
+      Map[String, DeltaColumnStat] = {
+
+      val dataColumns = snapshot.statCollectionSchema
+        .filter(col => columnStatsSupportedDataTypes.contains(col.dataType))
 
-  /** Return the number of rows in the table or `None` if we cannot calculate it from stats */
-  private def extractGlobalCount(tahoeLogFileIndex: TahoeLogFileIndex): Option[Long] = {
-    // account for deleted rows according to deletion vectors
-    val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0))
-    val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords")
-
-    val row = getDeltaScanGenerator(tahoeLogFileIndex).filesWithStatsForScan(Nil)
-      .agg(
-        sum(numLogicalRecords),
-        // Calculate the number of files missing `numRecords`
-        count(when(col("stats.numRecords").isNull, 1)))
-      .first
-
-    // The count agg is never null. A non-zero value means we have incomplete stats; otherwise,
-    // the sum agg is either null (for an empty table) or gives an accurate record count.
-    if (row.getLong(1) > 0) return None
-    val numRecords = if (row.isNullAt(0)) 0 else row.getLong(0)
-    Some(numRecords)
+      // Validate that all the files have stats
+      val filesStatsCount = deltaScanGenerator.filesWithStatsForScan(Nil).select(
+        count(when(col("stats.numRecords").isNull, 1)).as("missingNumRecords"),
+        count(when(col("stats.numRecords") > 0, 1)).as("countNonEmptyFiles")).head
+
+      val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0
+      // DELETE operations create AddFile records with 0 rows and no column stats.
+      // We can safely ignore them since there is no data.
+      lazy val files = deltaScanGenerator.filesWithStatsForScan(Nil)
+        .filter(col("stats.numRecords") > 0)
+      val numFiles: Long = filesStatsCount.getAs[Long]("countNonEmptyFiles")
+      lazy val statsMinMaxNullColumns = files.select(col("stats.*"))
+      if (dataColumns.isEmpty
+        || !allRecordsHasStats
+        || numFiles == 0
+        || !statsMinMaxNullColumns.columns.contains("minValues")
+        || !statsMinMaxNullColumns.columns.contains("maxValues")
+        || !statsMinMaxNullColumns.columns.contains("nullCount")) {
+        Map.empty
+      } else {
+        // dataColumns can contain columns without stats if dataSkippingNumIndexedCols
+        // has been increased
+        val columnsWithStats = files.select(
+          col("stats.minValues.*"),
+          col("stats.maxValues.*"),
+          col("stats.nullCount.*"))
+          .columns.groupBy(identity).mapValues(_.size)
+          .filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount
+          .map(x => x._1).toSet
+
+        // Create tuples with the physical name to avoid recalculating it multiple times
+        val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
+          .filter(x => columnsWithStats.contains(x._2))
+
+        val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
+          val dataType = columnAndPhysicalName._1.dataType
+          val physicalName = columnAndPhysicalName._2
+
+          Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"),
+            col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"),
+            col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName"))
+        } ++ Seq(col(s"stats.numRecords").as(s"numRecords"))
+
+        val minMaxNullCountExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
+          val physicalName = columnAndPhysicalName._2
+
+          // To validate that the column has complete stats we run two checks:
+          // 1) COUNT(nullCount.columnName) should be equal to numFiles,
+          //    since nullCount is always non-null.
+          // 2) The number of files with a non-null min/max:
+          //      a. count(min.columnName) | count(max.columnName)
+          //    plus the number of files where all rows are NULL:
+          //      b. count of (ISNULL(min.columnName) and nullCount.columnName == numRecords)
+          //    should be equal to numFiles.
+          Seq(
+            s"""case when $numFiles = count(`nullCount.$physicalName`)
+               | AND $numFiles = (count(`min.$physicalName`) + sum(case when
+               |  ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords
+               |  then 1 else 0 end))
+               | AND $numFiles = (count(`max.$physicalName`) + sum(case when
+               |  ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords
+               |  then 1 else 0 end))
+               | then TRUE else FALSE end as `complete_$physicalName`""".stripMargin,
+            s"min(`min.$physicalName`) as `min_$physicalName`",
+            s"max(`max.$physicalName`) as `max_$physicalName`",
+            s"sum(`nullCount.$physicalName`) as `nullCount_$physicalName`")
+        }
+
+        val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxNullCountExpr: _*).head
+
+        dataColumnsWithStats
+          .filter(x => statsResults.getAs[Boolean](s"complete_${x._2}"))
+          .map { columnAndPhysicalName =>
+            val column = columnAndPhysicalName._1
+            val physicalName = columnAndPhysicalName._2
+            column.name ->
+              DeltaColumnStat(
+                statsResults.getAs(s"min_$physicalName"),
+                statsResults.getAs(s"max_$physicalName"),
+                Some(statsResults.getAs[Long](s"nullCount_$physicalName")),
+                None)
+          }.toMap
+      }
+    }
+
+    def extractGlobalPartitionedColumnStatsDeltaLog(snapshot: Snapshot):
+      Map[String, DeltaColumnStat] = {
+
+      val partitionedColumns = snapshot.metadata.partitionSchema
+        .filter(x => columnStatsSupportedDataTypes.contains(x.dataType))
+        .map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
+
+      if (partitionedColumns.isEmpty) {
+        Map.empty
+      } else {
+        val partitionedColumnsValues = partitionedColumns.map { partitionedColumn =>
+          val physicalName = partitionedColumn._2
+          col(s"partitionValues.`$physicalName`")
+            .cast(partitionedColumn._1.dataType).as(physicalName)
+        }
+
+        val partitionedColumnsAgg = partitionedColumns.flatMap { partitionedColumn =>
+          val physicalName = partitionedColumn._2
+
+          Seq(min(s"`$physicalName`").as(s"min_$physicalName"),
+            max(s"`$physicalName`").as(s"max_$physicalName"),
+            count_distinct(col(s"`$physicalName`")).as(s"distinctCount_$physicalName"))
+        }
+
+        val partitionedColumnsQuery = snapshot.allFiles
+          .select(partitionedColumnsValues: _*)
+          .agg(partitionedColumnsAgg.head, partitionedColumnsAgg.tail: _*)
+          .head()
+
+        partitionedColumns.map { partitionedColumn =>
+          val physicalName = partitionedColumn._2
+
+          partitionedColumn._1.name ->
+            DeltaColumnStat(
+              partitionedColumnsQuery.getAs(s"min_$physicalName"),
+              partitionedColumnsQuery.getAs(s"max_$physicalName"),
+              None,
+              Some(partitionedColumnsQuery.getAs[Long](s"distinctCount_$physicalName")))
+        }.toMap
+      }
     }
+
+    CaseInsensitiveMap(
+      extractGlobalColumnStatsDeltaLog(snapshot).++
+        (extractGlobalPartitionedColumnStatsDeltaLog(snapshot)))
   }
 }
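
A note on `extractGlobalCount`: it subtracts each file's deletion-vector cardinality from `stats.numRecords`, so the logical row count stays correct for rows deleted via deletion vectors. In miniature, with invented numbers:

```scala
// Each tuple models one AddFile: (stats.numRecords, deletionVector.cardinality if any).
val perFile = Seq((100L, Some(7L)), (50L, None), (25L, Some(25L)))

// Mirrors coalesce(col("deletionVector.cardinality"), lit(0)) followed by a sum.
val numLogicalRecords = perFile.map { case (n, dv) => n - dv.getOrElse(0L) }.sum
// numLogicalRecords == 143
```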
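
The heart of `extractGlobalColumnStatsDeltaLog` is a fold of per-file min/max stats into global values, guarded by a completeness check: every non-empty file must either carry a min/max for the column or contain only NULLs for it (in which case `nullCount == numRecords`). A minimal self-contained sketch of that logic; `FileStats` and the sample values are invented:

```scala
// Simplified model of per-file column stats from the Delta log; min/max are
// None only when every row in the file is NULL for this column.
case class FileStats(numRecords: Long, min: Option[Int], max: Option[Int], nullCount: Long)

val files = Seq(
  FileStats(numRecords = 10, min = Some(3), max = Some(42), nullCount = 0),
  FileStats(numRecords = 5, min = None, max = None, nullCount = 5), // all-NULL file
  FileStats(numRecords = 7, min = Some(-1), max = Some(8), nullCount = 2))

// The same two checks the commit expresses in SQL: stats for the column are
// usable only if every file has a min/max or is entirely NULL.
val complete = files.forall(f => f.min.isDefined || f.nullCount == f.numRecords)

val mins = files.flatMap(_.min)
val maxes = files.flatMap(_.max)
val globalMin = if (complete && mins.nonEmpty) Some(mins.min) else None // Some(-1)
val globalMax = if (complete && maxes.nonEmpty) Some(maxes.max) else None // Some(42)
```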
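
One subtlety in `createLocalRelationPlan`: the rewritten plan is a `LocalRelation` over `InternalRow`s, and Spark's internal representation of `DateType` is an `Int` counting days since the Unix epoch. That is why `DateType` stats, which deserialize as `java.sql.Date`, go through `DateTimeUtils.fromJavaDate`. A small illustration (the date value is invented):

```scala
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// InternalRow stores DateType as an Int: days since 1970-01-01.
val javaDate = java.sql.Date.valueOf("2023-01-15")
val internalDays: Int = DateTimeUtils.fromJavaDate(javaDate)
// DateTimeUtils.toJavaDate(internalDays) round-trips to the original date.
```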
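
Finally, `extractGlobalPartitionedColumnStatsDeltaLog` answers MIN/MAX for partition columns from `partitionValues` rather than from `stats`, since every row in a file shares that file's partition value. A toy version of that aggregation over invented partition values:

```scala
// Invented partitionValues entries from AddFile actions; the log stores them
// as strings and the commit casts them to the column's data type on read.
val partitionValues = Seq("2023-01-01", "2023-03-15", "2022-12-31")

// ISO-formatted dates order correctly even as plain strings.
val minPartition = partitionValues.min // "2022-12-31"
val maxPartition = partitionValues.max // "2023-03-15"
```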
