From 22bd1781691795b28661899ddc94e8024b07ab41 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 25 Jan 2024 13:58:06 +0800 Subject: [PATCH] Support observe hint --- .../org/apache/kyuubi/sql/KyuubiSQLConf.scala | 8 ++ .../kyuubi/sql/KyuubiSparkSQLExtension.scala | 3 + .../sql/observe/ResolveObserveHints.scala | 105 ++++++++++++++++++ .../observe/ResolveObserveHintsSuite.scala | 50 +++++++++ 4 files changed, 166 insertions(+) create mode 100644 extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/observe/ResolveObserveHints.scala create mode 100644 extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/observe/ResolveObserveHintsSuite.scala diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala index 7c4e8d631ef..8ed7698263e 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala @@ -296,4 +296,12 @@ object KyuubiSQLConf { .version("1.9.0") .booleanConf .createWithDefault(true) + + val OBSERVE_HINT_ENABLE = + buildConf("spark.sql.optimizer.observeHint.enabled") + .doc(s"Provide OBSERVE Hint to create an observer to collect aggregated metrics." + + s" The OBSERVE Hint Syntax: /*+ OBSERVE(name, exprs) */.") + .version("1.9.0") + .booleanConf + .createWithDefault(false) } diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index fd11fb5f579..c6349258e5a 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,6 +19,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions} +import org.apache.kyuubi.sql.observe.ResolveObserveHints import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, KyuubiUnsupportedOperationsCheck, MaxScanStrategy} // scalastyle:off line.size.limit @@ -32,6 +33,8 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { override def apply(extensions: SparkSessionExtensions): Unit = { KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions) + extensions.injectResolutionRule(_ => ResolveObserveHints) + extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource) extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive) extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/observe/ResolveObserveHints.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/observe/ResolveObserveHints.scala new file mode 100644 index 00000000000..11ca2da50cc --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/observe/ResolveObserveHints.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kyuubi.sql.observe + +import java.util.Locale +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAlias} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Generator, NamedExpression, StringLiteral} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan, UnresolvedHint} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_HINT +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression + +import org.apache.kyuubi.sql.KyuubiSQLConf.OBSERVE_HINT_ENABLE + +/** + * A rule to resolve the OBSERVE hint. + * OBSERVE hint usage like: /*+ OBSERVE('name', exprs) */ + */ +object ResolveObserveHints extends Rule[LogicalPlan] { + + private val OBSERVE_HINT_NAME = "OBSERVE" + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(OBSERVE_HINT_ENABLE)) { + return plan + } + plan.resolveOperatorsWithPruning( + _.containsPattern(UNRESOLVED_HINT)) { + case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match { + case OBSERVE_HINT_NAME => + val (name, exprs) = hint.parameters match { + case Seq(StringLiteral(name), exprs @ _*) => (name, exprs) + case Seq(exprs @ _*) => (nextObserverName(), exprs) + } + + val invalidParams = exprs.filter(!_.isInstanceOf[Expression]) + if (invalidParams.nonEmpty) { + val hintName = hint.name.toUpperCase(Locale.ROOT) + throw invalidHintParameterError(hintName, invalidParams) + } + + // named exprs, copy from org.apache.spark.sql.Column.named method + val namedExprs = exprs.map { + case expr: NamedExpression => expr + // Leave an unaliased generator with an empty list of names since the analyzer will + // generate the correct defaults after the nested expression's type has been resolved. + case g: Generator => MultiAlias(g, Nil) + + // If we have a top level Cast, there is a chance to give it a better alias, + // if there is a NamedExpression under this Cast. + case c: Cast => + c.transformUp { + case c @ Cast(_: NamedExpression, _, _, _) => UnresolvedAlias(c) + } match { + case ne: NamedExpression => ne + case _ => UnresolvedAlias(c, Some(generateAlias)) + } + + case expr: Expression => UnresolvedAlias(expr, Some(generateAlias)) + } + + CollectMetrics(name, namedExprs, hint.child) + case _ => hint + } + } + } + + private val id = new AtomicLong(0) + private def nextObserverName(): String = s"OBSERVER_${id.getAndIncrement()}" + + private def invalidHintParameterError(hintName: String, invalidParams: Seq[Any]): Throwable = { + new AnalysisException( + errorClass = "_LEGACY_ERROR_TEMP_1047", + messageParameters = Map( + "hintName" -> hintName, + "invalidParams" -> invalidParams.mkString(", "))) + } + + private def generateAlias(e: Expression): String = { + e match { + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + a.aggregateFunction.toString + case expr => toPrettySQL(expr) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/observe/ResolveObserveHintsSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/observe/ResolveObserveHintsSuite.scala new file mode 100644 index 00000000000..bcbbe784322 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/observe/ResolveObserveHintsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.observe + +import org.apache.spark.sql.{KyuubiSparkSQLExtensionTest, QueryTest, Row} + +import org.apache.kyuubi.sql.KyuubiSQLConf.OBSERVE_HINT_ENABLE + +class ResolveObserveHintsSuite extends KyuubiSparkSQLExtensionTest { + + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("test observe hint") { + withSQLConf(OBSERVE_HINT_ENABLE.key -> "true") { + val sqlText = + s""" + | SELECT /*+ OBSERVE('observer3', sum(tt2.c3), count(1)) */ * + | FROM + | (SELECT /*+ OBSERVE('observer1', sum(c1), count(1)) */ * from t1) tt1 + | join + | (SELECT /*+ OBSERVE('observer2', sum(c1), count(1)) */ c1, c1 * 2 as c3 from t2) tt2 + | on tt1.c1 = tt2.c1 + |""".stripMargin + val df = spark.sql(sqlText) + df.collect() + val observedMetrics = df.queryExecution.observedMetrics + assert(observedMetrics.size == 3) + QueryTest.sameRows(Seq(observedMetrics("observer1")), Seq(Row(5050, 100))) + QueryTest.sameRows(Seq(observedMetrics("observer2")), Seq(Row(55, 10))) + QueryTest.sameRows(Seq(observedMetrics("observer3")), Seq(Row(110, 10))) + } + } +}