Commit f798b76
Extract Count and Min/Max in a single method. This allows extracting Min/Max from partitioned columns even when COUNT is not available.
Signed-off-by: Felipe Fujiy Pessoto <[email protected]>
1 parent 83c1671 commit f798b76
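
The refactoring collapses a two-step lookup (row count first, then min/max stats) into a single method, so a missing count no longer disables the MIN/MAX rewrite. A minimal, self-contained sketch of the new contract; the stub method and toy values are hypothetical, and only DeltaColumnStat and the tuple shape come from the diff below:

object CountMinMaxContract extends App {
  case class DeltaColumnStat(min: Any, max: Any)

  // Stand-in for extractCountMinMaxFromDeltaLog: stats are disabled, so the row
  // count is unknown (None), but partition values still yield min/max for Column2.
  def extractCountMinMax(): (Option[Long], Map[String, DeltaColumnStat]) =
    (None, Map("Column2" -> DeltaColumnStat(3, 4)))

  val (rowCount, columnStats) = extractCountMinMax()
  assert(rowCount.isEmpty)                // COUNT(*) cannot be rewritten...
  assert(columnStats("Column2").min == 3) // ...but MIN and MAX still can be
}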

File tree

2 files changed: +168 -153 lines changed

spark/src/main/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuery.scala

+147 -153 lines changed
@@ -55,66 +55,59 @@ trait OptimizeMetadataOnlyDeltaQuery {
   private def createLocalRelationPlan(
       plan: Aggregate,
       tahoeLogFileIndex: TahoeLogFileIndex): LogicalPlan = {
-    val rowCount = extractGlobalCount(tahoeLogFileIndex)
-
-    if (rowCount.isDefined) {
-      val aggColumnsNames = Set(extractMinMaxFieldNames(plan).map(_.toLowerCase(Locale.ROOT)) : _*)
-      val columnStats = extractMinMaxFromDeltaLog(tahoeLogFileIndex, aggColumnsNames)
-
-      def checkStatsExists(attrRef: AttributeReference): Boolean = {
-        columnStats.contains(attrRef.name) &&
-          // Avoid StructType, it is not supported by this optimization
-          // Sanity check only. If reference is nested column it would be GetStructType
-          // instead of AttributeReference
-          attrRef.references.size == 1 &&
-          attrRef.references.head.dataType != StructType
-      }
 
-      def convertValueIfRequired(attrRef: AttributeReference, value: Any): Any = {
-        if (attrRef.dataType == DateType && value != null) {
-          DateTimeUtils.fromJavaDate(value.asInstanceOf[Date])
-        } else {
-          value
-        }
-      }
+    val aggColumnsNames = Set(extractMinMaxFieldNames(plan).map(_.toLowerCase(Locale.ROOT)) : _*)
+    val (rowCount, columnStats) = extractCountMinMaxFromDeltaLog(tahoeLogFileIndex, aggColumnsNames)
 
-      val rewrittenAggregationValues = plan.aggregateExpressions.collect {
-        case Alias(AggregateExpression(
-          Count(Seq(Literal(1, _))), Complete, false, None, _), _) =>
-          rowCount.get
-        case Alias(tps@ToPrettyString(AggregateExpression(
-          Count(Seq(Literal(1, _))), Complete, false, None, _), _), _) =>
-          tps.copy(child = Literal(rowCount.get)).eval()
-        case Alias(AggregateExpression(
-          Min(minReference: AttributeReference), Complete, false, None, _), _)
-          if checkStatsExists(minReference) =>
-          convertValueIfRequired(minReference, columnStats(minReference.name).min)
-        case Alias(tps@ToPrettyString(AggregateExpression(
-          Min(minReference: AttributeReference), Complete, false, None, _), _), _)
-          if checkStatsExists(minReference) =>
-          val v = columnStats(minReference.name).min
-          tps.copy(child = Literal(v)).eval()
-        case Alias(AggregateExpression(
-          Max(maxReference: AttributeReference), Complete, false, None, _), _)
-          if checkStatsExists(maxReference) =>
-          convertValueIfRequired(maxReference, columnStats(maxReference.name).max)
-        case Alias(tps@ToPrettyString(AggregateExpression(
-          Max(maxReference: AttributeReference), Complete, false, None, _), _), _)
-          if checkStatsExists(maxReference) =>
-          val v = columnStats(maxReference.name).max
-          tps.copy(child = Literal(v)).eval()
-      }
+    def checkStatsExists(attrRef: AttributeReference): Boolean = {
+      columnStats.contains(attrRef.name) &&
+        // Avoid StructType, it is not supported by this optimization
+        // Sanity check only. If reference is nested column it would be GetStructType
+        // instead of AttributeReference
+        attrRef.references.size == 1 && attrRef.references.head.dataType != StructType
+    }
 
-      if (plan.aggregateExpressions.size == rewrittenAggregationValues.size) {
-        val r = LocalRelation(
-          plan.output,
-          Seq(InternalRow.fromSeq(rewrittenAggregationValues)))
-        r
+    def convertValueIfRequired(attrRef: AttributeReference, value: Any): Any = {
+      if (attrRef.dataType == DateType && value != null) {
+        DateTimeUtils.fromJavaDate(value.asInstanceOf[Date])
       } else {
-        plan
+        value
       }
     }
-    else {
+
+    val rewrittenAggregationValues = plan.aggregateExpressions.collect {
+      case Alias(AggregateExpression(
+        Count(Seq(Literal(1, _))), Complete, false, None, _), _) if rowCount.isDefined =>
+        rowCount.get
+      case Alias(tps@ToPrettyString(AggregateExpression(
+        Count(Seq(Literal(1, _))), Complete, false, None, _), _), _) if rowCount.isDefined =>
+        tps.copy(child = Literal(rowCount.get)).eval()
+      case Alias(AggregateExpression(
+        Min(minReference: AttributeReference), Complete, false, None, _), _)
+        if checkStatsExists(minReference) =>
+        convertValueIfRequired(minReference, columnStats(minReference.name).min)
+      case Alias(tps@ToPrettyString(AggregateExpression(
+        Min(minReference: AttributeReference), Complete, false, None, _), _), _)
+        if checkStatsExists(minReference) =>
+        val v = columnStats(minReference.name).min
+        tps.copy(child = Literal(v)).eval()
+      case Alias(AggregateExpression(
+        Max(maxReference: AttributeReference), Complete, false, None, _), _)
+        if checkStatsExists(maxReference) =>
+        convertValueIfRequired(maxReference, columnStats(maxReference.name).max)
+      case Alias(tps@ToPrettyString(AggregateExpression(
+        Max(maxReference: AttributeReference), Complete, false, None, _), _), _)
+        if checkStatsExists(maxReference) =>
+        val v = columnStats(maxReference.name).max
+        tps.copy(child = Literal(v)).eval()
+    }
+
+    if (plan.aggregateExpressions.size == rewrittenAggregationValues.size) {
+      val r = LocalRelation(
+        plan.output,
+        Seq(InternalRow.fromSeq(rewrittenAggregationValues)))
+      r
+    } else {
       plan
     }
   }
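
The key change in this hunk: rowCount.isDefined now guards only the two COUNT cases inside the collect, instead of wrapping the whole rewrite, while the final all-or-nothing check still keeps the original plan whenever any aggregate fails to match. A small sketch of that gating pattern, using plain values instead of Catalyst expressions (the Agg types here are illustrative):

object CollectGuardDemo extends App {
  sealed trait Agg
  case object CountStar extends Agg
  case class MinOf(column: String) extends Agg

  val rowCount: Option[Long] = None     // stats disabled: count unknown
  val columnStats = Map("Column2" -> 3) // partition values: min known

  def rewrite(aggs: Seq[Agg]): Option[Seq[Any]] = {
    val rewritten = aggs.collect {
      case CountStar if rowCount.isDefined => rowCount.get
      case MinOf(c) if columnStats.contains(c) => columnStats(c)
    }
    // Same all-or-nothing rule as above: a partial rewrite keeps the plan.
    if (rewritten.size == aggs.size) Some(rewritten) else None
  }

  assert(rewrite(Seq(MinOf("Column2"))).isDefined)          // MIN alone: optimized
  assert(rewrite(Seq(CountStar, MinOf("Column2"))).isEmpty) // COUNT blocks the rewrite
}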
@@ -136,122 +129,114 @@ trait OptimizeMetadataOnlyDeltaQuery {
     }
   }
 
-  /** Return the number of rows in the table or `None` if we cannot calculate it from stats */
-  private def extractGlobalCount(tahoeLogFileIndex: TahoeLogFileIndex): Option[Long] = {
-    // account for deleted rows according to deletion vectors
-    val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0))
-    val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords")
-    val row = getDeltaScanGenerator(tahoeLogFileIndex).filesWithStatsForScan(Nil)
-      .agg(
-        sum(numLogicalRecords),
-        // Calculate the number of files missing `numRecords`
-        count(when(col("stats.numRecords").isNull, 1)))
-      .first
-
-    // The count agg is never null. A non-zero value means we have incomplete stats; otherwise,
-    // the sum agg is either null (for an empty table) or gives an accurate record count.
-    if (row.getLong(1) > 0) return None
-    val numRecords = if (row.isNullAt(0)) 0 else row.getLong(0)
-    Some(numRecords)
-  }
-
   /**
    * Min and max values from Delta Log stats or partitionValues.
    */
   case class DeltaColumnStat(min: Any, max: Any)
 
-  private def extractMinMaxFromStats(
+  private def extractCountMinMaxFromStats(
       deltaScanGenerator: DeltaScanGenerator,
-      lowerCaseColumnNames: Set[String]): Map[String, DeltaColumnStat] = {
-
+      lowerCaseColumnNames: Set[String]): (Option[Long], Map[String, DeltaColumnStat]) = {
     // TODO Update this to work with DV (https://github.com/delta-io/delta/issues/1485)
+
     val snapshot = deltaScanGenerator.snapshotToScan
-    val dataColumns = snapshot.statCollectionPhysicalSchema.filter(col =>
-      AggregateDeltaTable.isSupportedDataType(col.dataType) &&
-      lowerCaseColumnNames.contains(col.name.toLowerCase(Locale.ROOT)))
 
+    // Count - account for deleted rows according to deletion vectors
+    val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0))
+    val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords")
+
+    val filesWithStatsForScan = deltaScanGenerator.filesWithStatsForScan(Nil)
     // Validate all the files has stats
-    lazy val filesStatsCount = deltaScanGenerator.filesWithStatsForScan(Nil).select(
+    val filesStatsCount = filesWithStatsForScan.select(
+      sum(numLogicalRecords).as("numLogicalRecords"),
      count(when(col("stats.numRecords").isNull, 1)).as("missingNumRecords"),
      count(when(col("stats.numRecords") > 0, 1)).as("countNonEmptyFiles")).head
 
-    lazy val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0
+    // If any numRecords is null, we have incomplete stats;
+    val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0
+    if (!allRecordsHasStats) {
+      return (None, Map.empty)
+    }
+    // the sum agg is either null (for an empty table) or gives an accurate record count.
+    val numRecords = if (filesStatsCount.isNullAt(0)) 0 else filesStatsCount.getLong(0)
     lazy val numFiles: Long = filesStatsCount.getAs[Long]("countNonEmptyFiles")
 
+    val dataColumns = snapshot.statCollectionPhysicalSchema.filter(col =>
+      AggregateDeltaTable.isSupportedDataType(col.dataType) &&
+      lowerCaseColumnNames.contains(col.name.toLowerCase(Locale.ROOT)))
+
     // DELETE operations creates AddFile records with 0 rows, and no column stats.
     // We can safely ignore it since there is no data.
-    lazy val files = deltaScanGenerator.filesWithStatsForScan(Nil)
-      .filter(col("stats.numRecords") > 0)
+    lazy val files = filesWithStatsForScan.filter(col("stats.numRecords") > 0)
     lazy val statsMinMaxNullColumns = files.select(col("stats.*"))
     if (dataColumns.isEmpty
      || !isTableDVFree(snapshot)
-     || !allRecordsHasStats
      || numFiles == 0
      || !statsMinMaxNullColumns.columns.contains("minValues")
      || !statsMinMaxNullColumns.columns.contains("maxValues")
      || !statsMinMaxNullColumns.columns.contains("nullCount")) {
-      Map.empty
-    } else {
-      // dataColumns can contain columns without stats if dataSkippingNumIndexedCols
-      // has been increased
-      val columnsWithStats = files.select(
-        col("stats.minValues.*"),
-        col("stats.maxValues.*"),
-        col("stats.nullCount.*"))
-        .columns.groupBy(identity).mapValues(_.size)
-        .filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount
-        .map(x => x._1).toSet
-
-      // Creates a tuple with physical name to avoid recalculating it multiple times
-      val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
-        .filter(x => columnsWithStats.contains(x._2))
-
-      val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
-        val dataType = columnAndPhysicalName._1.dataType
-        val physicalName = columnAndPhysicalName._2
-
-        Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"),
-          col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"),
-          col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName"))
-      } ++ Seq(col(s"stats.numRecords").as(s"numRecords"))
+      return (Some(numRecords), Map.empty)
+    }
 
-      val minMaxExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
-        val physicalName = columnAndPhysicalName._2
+    // dataColumns can contain columns without stats if dataSkippingNumIndexedCols
+    // has been increased
+    val columnsWithStats = files.select(
+      col("stats.minValues.*"),
+      col("stats.maxValues.*"),
+      col("stats.nullCount.*"))
+      .columns.groupBy(identity).mapValues(_.size)
+      .filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount
+      .map(x => x._1).toSet
+
+    // Creates a tuple with physical name to avoid recalculating it multiple times
+    val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
+      .filter(x => columnsWithStats.contains(x._2))
+
+    val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
+      val dataType = columnAndPhysicalName._1.dataType
+      val physicalName = columnAndPhysicalName._2
+
+      Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"),
+        col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"),
+        col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName"))
+    } ++ Seq(col(s"stats.numRecords").as(s"numRecords"))
+
+    val minMaxExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
+      val physicalName = columnAndPhysicalName._2
+
+      // To validate if the column has stats we do two validation:
+      // 1-) COUNT(nullCount.columnName) should be equals to numFiles,
+      // since nullCount is always non-null.
+      // 2-) The number of files with non-null min/max:
+      // a. count(min.columnName)|count(max.columnName) +
+      // the number of files where all rows are NULL:
+      // b. count of (ISNULL(min.columnName) and nullCount.columnName == numRecords)
+      // should be equals to numFiles
+      Seq(
+        s"""case when $numFiles = count(`nullCount.$physicalName`)
+           | AND $numFiles = (count(`min.$physicalName`) + sum(case when
+           |   ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords
+           |   then 1 else 0 end))
+           | AND $numFiles = (count(`max.$physicalName`) + sum(case when
+           |   ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords
+           |   then 1 else 0 end))
+           | then TRUE else FALSE end as `complete_$physicalName`""".stripMargin,
+        s"min(`min.$physicalName`) as `min_$physicalName`",
+        s"max(`max.$physicalName`) as `max_$physicalName`")
+    }
 
-        // To validate if the column has stats we do two validation:
-        // 1-) COUNT(nullCount.columnName) should be equals to numFiles,
-        // since nullCount is always non-null.
-        // 2-) The number of files with non-null min/max:
-        // a. count(min.columnName)|count(max.columnName) +
-        // the number of files where all rows are NULL:
-        // b. count of (ISNULL(min.columnName) and nullCount.columnName == numRecords)
-        // should be equals to numFiles
-        Seq(
-          s"""case when $numFiles = count(`nullCount.$physicalName`)
-             | AND $numFiles = (count(`min.$physicalName`) + sum(case when
-             |   ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords
-             |   then 1 else 0 end))
-             | AND $numFiles = (count(`max.$physicalName`) + sum(case when
-             |   ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords
-             |   then 1 else 0 end))
-             | then TRUE else FALSE end as `complete_$physicalName`""".stripMargin,
-          s"min(`min.$physicalName`) as `min_$physicalName`",
-          s"max(`max.$physicalName`) as `max_$physicalName`")
-      }
+    val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxExpr: _*).head
 
-      val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxExpr: _*).head
-
-      dataColumnsWithStats
-        .filter(x => statsResults.getAs[Boolean](s"complete_${x._2}"))
-        .map { columnAndPhysicalName =>
-          val column = columnAndPhysicalName._1
-          val physicalName = columnAndPhysicalName._2
-          column.name ->
-            DeltaColumnStat(
-              statsResults.getAs(s"min_$physicalName"),
-              statsResults.getAs(s"max_$physicalName"))
-        }.toMap
-    }
+    (Some(numRecords), dataColumnsWithStats
+      .filter(x => statsResults.getAs[Boolean](s"complete_${x._2}"))
+      .map { columnAndPhysicalName =>
+        val column = columnAndPhysicalName._1
+        val physicalName = columnAndPhysicalName._2
+        column.name ->
+          DeltaColumnStat(
+            statsResults.getAs(s"min_$physicalName"),
+            statsResults.getAs(s"max_$physicalName"))
+      }.toMap)
   }
 
   private def extractMinMaxFromPartitionValue(
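
The complete_<column> expression built above decides when file-level stats can answer a global MIN/MAX: every non-empty file must report a nullCount, and each file must either carry a min/max or be entirely NULL for that column (nullCount equal to numRecords). The same rule restated over an in-memory list of per-file stats, as a sketch with made-up numbers:

object CompletenessDemo extends App {
  case class FileStats(
      numRecords: Long, nullCount: Option[Long], min: Option[Int], max: Option[Int])

  def complete(files: Seq[FileStats]): Boolean = {
    val nonEmpty = files.filter(_.numRecords > 0)
    val numFiles = nonEmpty.size
    numFiles == nonEmpty.count(_.nullCount.isDefined) &&
      numFiles == nonEmpty.count(f => f.min.isDefined || f.nullCount.contains(f.numRecords)) &&
      numFiles == nonEmpty.count(f => f.max.isDefined || f.nullCount.contains(f.numRecords))
  }

  // Second file is all NULL for the column: still complete, min/max come from file one.
  assert(complete(Seq(
    FileStats(10, Some(0), Some(1), Some(9)),
    FileStats(4, Some(4), None, None))))

  // Second file has rows but no min/max stats: incomplete, the optimization must bail out.
  assert(!complete(Seq(
    FileStats(10, Some(0), Some(1), Some(9)),
    FileStats(4, Some(1), None, None))))
}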
@@ -295,21 +280,28 @@ trait OptimizeMetadataOnlyDeltaQuery {
     }
   }
 
-  private def extractMinMaxFromDeltaLog(
+  /**
+   * Extract the Count, Min and Max values from Delta Log stats and partitionValues.
+   * The first field is the row count of the table, or `None` if we cannot calculate it from stats.
+   * If the column is not partitioned, the values are extracted from stats when they exist.
+   * If the column is partitioned, the values are extracted from partitionValues.
+   */
+  private def extractCountMinMaxFromDeltaLog(
      tahoeLogFileIndex: TahoeLogFileIndex,
      lowerCaseColumnNames: Set[String]):
-    CaseInsensitiveMap[DeltaColumnStat] = {
-    val deltaScanGenerator = getDeltaScanGenerator(tahoeLogFileIndex)
-    val snapshot = deltaScanGenerator.snapshotToScan
-    val columnFromStats = extractMinMaxFromStats(deltaScanGenerator, lowerCaseColumnNames)
+    (Option[Long], CaseInsensitiveMap[DeltaColumnStat]) = {
+    val deltaScanGen = getDeltaScanGenerator(tahoeLogFileIndex)
+    val (rowCount, columnStats) = extractCountMinMaxFromStats(deltaScanGen, lowerCaseColumnNames)
 
-    if(lowerCaseColumnNames.equals(columnFromStats.keySet)) {
-      CaseInsensitiveMap(columnFromStats)
+    val minMaxValues = if (lowerCaseColumnNames.equals(columnStats.keySet)) {
+      CaseInsensitiveMap(columnStats)
     } else {
       CaseInsensitiveMap(
-        columnFromStats.++
-          (extractMinMaxFromPartitionValue(snapshot, lowerCaseColumnNames)))
+        columnStats.++
+          (extractMinMaxFromPartitionValue(deltaScanGen.snapshotToScan, lowerCaseColumnNames)))
     }
+
+    (rowCount, minMaxValues)
   }
 
   object AggregateDeltaTable {
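
extractCountMinMaxFromDeltaLog now composes the two sources: stats-derived values are kept, and any requested column they do not cover is filled from partition values. A toy sketch of that merge (plain maps, with the case-insensitive wrapper omitted):

object FallbackDemo extends App {
  val requested = Set("column1", "column2")
  val fromStats = Map.empty[String, (Any, Any)]        // stats disabled: nothing here
  val fromPartitionValues = Map("column2" -> ((3, 4))) // partition column min/max

  val minMax =
    if (requested == fromStats.keySet) fromStats
    else fromStats ++ fromPartitionValues

  assert(minMax.get("column2").contains((3, 4)))
}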
@@ -322,7 +314,9 @@ trait OptimizeMetadataOnlyDeltaQuery {
       dataType.isInstanceOf[DateType]
     }
 
-    def getAggFunctionOptimizable(aggExpr: AggregateExpression): Option[DeclarativeAggregate] = {
+    private def getAggFunctionOptimizable(
+        aggExpr: AggregateExpression): Option[DeclarativeAggregate] = {
+
       aggExpr match {
         case AggregateExpression(
           c@Count(Seq(Literal(1, _))), Complete, false, None, _) =>

spark/src/test/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuerySuite.scala

+21 lines changed
@@ -506,6 +506,27 @@ class OptimizeMetadataOnlyDeltaQuerySuite
     }
   }
 
+  test("min-max - partitioned column stats disabled") {
+    withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "false") {
+      val tableName = "TestPartitionedNoStats"
+
+      spark.sql(s"CREATE TABLE $tableName (Column1 INT, Column2 INT)" +
+        " USING DELTA PARTITIONED BY (Column2)")
+
+      spark.sql(s"INSERT INTO $tableName (Column1, Column2) VALUES (1, 3);")
+      spark.sql(s"INSERT INTO $tableName (Column1, Column2) VALUES (2, 4);")
+
+      // Has no stats, including COUNT
+      checkOptimizationIsNotTriggered(
+        s"SELECT COUNT(*), MIN(Column2), MAX(Column2) FROM $tableName")
+
+      // Should work for partitioned columns even without stats
+      checkResultsAndOptimizedPlan(
+        s"SELECT MIN(Column2), MAX(Column2) FROM $tableName",
+        "LocalRelation [none#0, none#1]")
+    }
+  }
+
   test("min-max - recompute column missing stats") {
     val tableName = "TestRecomputeMissingStat"
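
The behavior the new test pins down, written out as a hypothetical spark-shell session (it assumes a Delta-enabled SparkSession named spark, and that spark.databricks.delta.stats.collect is the key behind DeltaSQLConf.DELTA_COLLECT_STATS):

spark.conf.set("spark.databricks.delta.stats.collect", "false")
spark.sql("CREATE TABLE TestPartitionedNoStats (Column1 INT, Column2 INT)" +
  " USING DELTA PARTITIONED BY (Column2)")
spark.sql("INSERT INTO TestPartitionedNoStats VALUES (1, 3), (2, 4)")

// MIN/MAX on the partition column is answered from partitionValues alone,
// so the optimized plan collapses to a LocalRelation with no file scan.
spark.sql("SELECT MIN(Column2), MAX(Column2) FROM TestPartitionedNoStats").explain()

// COUNT(*) still needs numRecords stats, so adding it keeps the normal scan.
spark.sql("SELECT COUNT(*), MIN(Column2), MAX(Column2) FROM TestPartitionedNoStats").explain()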