From aaffefbf40a791e7f408f1687c84c2124923d54c Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 8 Mar 2025 21:16:27 -0800 Subject: [PATCH 1/2] Rename AstVisitor visitRowField to visitRowDataTypeField --- .../src/main/java/io/trino/sql/ExpressionFormatter.java | 2 +- .../src/main/java/io/trino/sql/tree/AstVisitor.java | 2 +- .../src/main/java/io/trino/sql/tree/RowDataType.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index cb2fe276fd23..c1f283a1f548 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -711,7 +711,7 @@ protected String visitRowDataType(RowDataType node, Void context) } @Override - protected String visitRowField(RowDataType.Field node, Void context) + protected String visitRowDataTypeField(RowDataType.Field node, Void context) { StringBuilder result = new StringBuilder(); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index 1da84c288c6f..981bded97775 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -952,7 +952,7 @@ protected R visitGenericDataType(GenericDataType node, C context) return visitDataType(node, context); } - protected R visitRowField(RowDataType.Field node, C context) + protected R visitRowDataTypeField(RowDataType.Field node, C context) { return visitNode(node, context); } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/RowDataType.java b/core/trino-parser/src/main/java/io/trino/sql/tree/RowDataType.java index e085db3c14f8..a21a38a8dada 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/RowDataType.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/RowDataType.java @@ -120,7 +120,7 @@ public List getChildren() @Override protected R accept(AstVisitor visitor, C context) { - return visitor.visitRowField(this, context); + return visitor.visitRowDataTypeField(this, context); } @Override From f0fb5757c6c8d5dc2c9cd60c3e94bc47c083fc2f Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Tue, 11 Mar 2025 19:23:35 -0700 Subject: [PATCH 2/2] Allow field name declaration in row literal Add support for `row(a 1, b 2)` instead of the much more complex `cast(row(1, 2) as row(a integer, b integer))`. --- .../antlr4/io/trino/grammar/sql/SqlBase.g4 | 6 +- .../sql/analyzer/AggregationAnalyzer.java | 8 +- .../sql/analyzer/ExpressionAnalyzer.java | 6 +- .../trino/sql/analyzer/StatementAnalyzer.java | 2 +- .../io/trino/sql/ir/ExpressionFormatter.java | 17 ++- .../trino/sql/ir/ExpressionTreeRewriter.java | 2 +- .../src/main/java/io/trino/sql/ir/Row.java | 30 ++++- .../ir/optimizer/IrExpressionOptimizer.java | 2 +- .../sql/ir/optimizer/rule/EvaluateRow.java | 2 +- .../io/trino/sql/planner/RelationPlanner.java | 14 ++- .../io/trino/sql/planner/TranslationMap.java | 15 ++- .../rule/ExpressionRewriteRuleSet.java | 16 +-- .../iterative/rule/PushCastIntoRow.java | 58 ++++----- ...mCorrelatedSingleRowSubqueryToProject.java | 2 +- .../trino/sql/rewrite/ShowStatsRewrite.java | 36 +++--- .../rule/TestExpressionRewriteRuleSet.java | 2 +- .../iterative/rule/TestPushCastIntoRow.java | 14 ++- .../java/io/trino/sql/query/TestGroupBy.java | 2 +- .../io/trino/sql/ExpressionFormatter.java | 16 ++- .../src/main/java/io/trino/sql/QueryUtil.java | 3 +- .../main/java/io/trino/sql/SqlFormatter.java | 12 +- .../java/io/trino/sql/parser/AstBuilder.java | 15 ++- .../java/io/trino/sql/tree/AstVisitor.java | 5 + .../sql/tree/DefaultTraversalVisitor.java | 14 ++- .../sql/tree/ExpressionTreeRewriter.java | 10 +- .../src/main/java/io/trino/sql/tree/Row.java | 111 ++++++++++++++++-- .../io/trino/sql/TestExpressionFormatter.java | 37 ++++++ .../io/trino/sql/parser/TestSqlParser.java | 16 +-- .../parser/TestSqlParserErrorHandling.java | 4 +- .../io/trino/testing/BaseConnectorTest.java | 4 +- 30 files changed, 353 insertions(+), 128 deletions(-) diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index 9334c12d0d1c..532505c5d4e5 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -574,7 +574,7 @@ primaryExpression | QUESTION_MARK #parameter | POSITION '(' valueExpression IN valueExpression ')' #position | '(' expression (',' expression)+ ')' #rowConstructor - | ROW '(' expression (',' expression)* ')' #rowConstructor + | ROW '(' fieldConstructor (',' fieldConstructor)* ')' #rowConstructor | name=LISTAGG '(' setQuantifier? expression (',' string)? (ON OVERFLOW listAggOverflowBehavior)? ')' (WITHIN GROUP '(' ORDER BY sortItem (',' sortItem)* ')') @@ -646,6 +646,10 @@ primaryExpression ')' #jsonArray ; +fieldConstructor + : expression (AS? identifier)? + ; + jsonPathInvocation : jsonValueExpression ',' path=string (AS pathName=identifier)? diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java index 42343b70df13..ba4061fc0c72 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java @@ -701,10 +701,16 @@ protected Boolean visitTryExpression(TryExpression node, Void context) @Override protected Boolean visitRow(Row node, Void context) { - return node.getItems().stream() + return node.getFields().stream() .allMatch(item -> process(item, context)); } + @Override + protected Boolean visitRowField(Row.Field node, Void context) + { + return process(node.getExpression(), context); + } + @Override protected Boolean visitParameter(Parameter node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index e7e050f32b7b..4ecebe3a3202 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -679,11 +679,11 @@ public Type process(Node node, @Nullable Context context) @Override protected Type visitRow(Row node, Context context) { - List types = node.getItems().stream() - .map(child -> process(child, context)) + List fields = node.getFields().stream() + .map(field -> new RowType.Field(field.getName().map(Identifier::getCanonicalValue), process(field.getExpression(), context))) .collect(toImmutableList()); - Type type = RowType.anonymous(types); + Type type = RowType.from(fields); return setExpressionType(node, type); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 33de0fcc09c6..ebc8f859e907 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -3929,7 +3929,7 @@ protected Scope visitValues(Values node, Optional scope) // TODO coerce the whole Row and add an Optimizer rule that converts CAST(ROW(...) AS ...) into ROW(CAST(...), CAST(...), ...). // The rule would also handle Row-type expressions that were specified as CAST(ROW). It should support multiple casts over a ROW. for (int i = 0; i < actualType.getTypeParameters().size(); i++) { - Expression item = ((Row) row).getItems().get(i); + Expression item = ((Row) row).getFields().get(i).getExpression(); Type actualItemType = actualType.getTypeParameters().get(i); Type expectedItemType = commonSuperType.getTypeParameters().get(i); if (!actualItemType.equals(expectedItemType)) { diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java index fac2a33f6eba..8dfcbb2cad75 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java @@ -16,6 +16,7 @@ import com.google.common.base.CharMatcher; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import io.trino.spi.type.RowType; import io.trino.sql.planner.Symbol; import java.util.List; @@ -67,9 +68,19 @@ protected String visitArray(Array node, Void context) @Override protected String visitRow(Row node, Void context) { - return node.items().stream() - .map(child -> process(child, context)) - .collect(joining(", ", "ROW (", ")")); + List fieldTypes = ((RowType) node.type()).getFields(); + + StringBuilder builder = new StringBuilder(); + builder.append("ROW ("); + for (int i = 0; i < fieldTypes.size(); i++) { + if (i > 0) { + builder.append(", "); + } + builder.append(node.items().get(i).accept(this, context)); + fieldTypes.get(i).getName().ifPresent(name -> builder.append(" AS ").append(name)); + } + builder.append(")"); + return builder.toString(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java index 5cbcc1c48a3f..7ac1cc71fdbb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java @@ -110,7 +110,7 @@ protected Expression visitRow(Row node, Context context) List items = rewrite(node.items(), context); if (!sameElements(node.items(), items)) { - return new Row(items); + return new Row(items, node.type()); } return node; diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/Row.java b/core/trino-main/src/main/java/io/trino/sql/ir/Row.java index d7be312cff17..f34e14f5654d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/Row.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/Row.java @@ -16,15 +16,15 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.collect.ImmutableList; import io.trino.spi.type.RowType; -import io.trino.spi.type.Type; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static java.util.Objects.requireNonNull; @JsonSerialize -public record Row(List items) +public record Row(List items, RowType type) implements Expression { public Row @@ -33,10 +33,9 @@ public record Row(List items) items = ImmutableList.copyOf(items); } - @Override - public Type type() + public Row(List items) { - return RowType.anonymous(items.stream().map(Expression::type).collect(Collectors.toList())); + this(items, RowType.anonymous(items.stream().map(Expression::type).toList())); } @Override @@ -60,4 +59,25 @@ public String toString() .collect(Collectors.joining(", ")) + ")"; } + + @JsonSerialize + public record Field(Optional name, Expression value) + { + public Field + { + requireNonNull(name, "name is null"); + requireNonNull(value, "value is null"); + } + + public static Field anonymousField(Expression value) + { + return new Field(Optional.empty(), value); + } + + @Override + public String toString() + { + return name.map(n -> n + " " + value).orElseGet(value::toString); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/IrExpressionOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/IrExpressionOptimizer.java index 5357a7ccb007..816bb57c58f7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/IrExpressionOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/IrExpressionOptimizer.java @@ -192,7 +192,7 @@ private Optional processChildren(Expression expression, Session sess case Logical logical -> process(logical.terms(), session, bindings).map(arguments -> new Logical(logical.operator(), arguments)); case Call call -> process(call.arguments(), session, bindings).map(arguments -> new Call(call.function(), arguments)); case Array array -> process(array.elements(), session, bindings).map(elements -> new Array(array.elementType(), elements)); - case Row row -> process(row.items(), session, bindings).map(fields -> new Row(fields)); + case Row row -> process(row.items(), session, bindings).map(fields -> new Row(fields, row.type())); case Between between -> { Optional value = process(between.value(), session, bindings); Optional min = process(between.min(), session, bindings); diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/rule/EvaluateRow.java b/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/rule/EvaluateRow.java index 115a98926457..e84bc8efc235 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/rule/EvaluateRow.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/optimizer/rule/EvaluateRow.java @@ -37,7 +37,7 @@ public class EvaluateRow @Override public Optional apply(Expression expression, Session session, Map bindings) { - if (!(expression instanceof Row(List fields)) || !fields.stream().allMatch(Constant.class::isInstance)) { + if (!(expression instanceof Row(List fields, RowType _)) || !fields.stream().allMatch(Constant.class::isInstance)) { return Optional.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 2a6610c800d3..91904184cdbc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -1764,10 +1764,16 @@ protected RelationPlan visitValues(Values node, Void context) ImmutableList.Builder rows = ImmutableList.builder(); for (io.trino.sql.tree.Expression row : node.getRows()) { - if (row instanceof io.trino.sql.tree.Row) { - rows.add(new Row(((io.trino.sql.tree.Row) row).getItems().stream() - .map(item -> coerceIfNecessary(analysis, item, translationMap.rewrite(item))) - .collect(toImmutableList()))); + if (row instanceof io.trino.sql.tree.Row astRow) { + ImmutableList.Builder fields = ImmutableList.builder(); + ImmutableList.Builder typeFields = ImmutableList.builder(); + for (int i = 0; i < astRow.getFields().size(); i++) { + io.trino.sql.tree.Row.Field astField = astRow.getFields().get(i); + Expression expression = coerceIfNecessary(analysis, astField.getExpression(), translationMap.rewrite(astField.getExpression())); + fields.add(expression); + typeFields.add(new RowType.Field(astField.getName().map(Identifier::getCanonicalValue), expression.type())); + } + rows.add(new Row(fields.build(), RowType.from(typeFields.build()))); } else if (analysis.getType(row) instanceof RowType) { rows.add(coerceIfNecessary(analysis, row, translationMap.rewrite(row))); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index 2ba0abbae52e..e3076758c963 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -150,6 +150,7 @@ import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.IrExpressions.ifExpression; import static io.trino.sql.ir.IrExpressions.not; +import static io.trino.sql.planner.QueryPlanner.coerceIfNecessary; import static io.trino.sql.planner.ScopeAware.scopeAwareKey; import static io.trino.sql.tree.JsonQuery.EmptyOrErrorBehavior.ERROR; import static io.trino.sql.tree.JsonQuery.QuotesBehavior.KEEP; @@ -545,11 +546,17 @@ private io.trino.sql.ir.Expression translate(NotExpression expression) return not(plannerContext.getMetadata(), translateExpression(expression.getValue())); } - private io.trino.sql.ir.Expression translate(Row expression) + private io.trino.sql.ir.Expression translate(Row row) { - return new io.trino.sql.ir.Row(expression.getItems().stream() - .map(this::translateExpression) - .collect(toImmutableList())); + ImmutableList.Builder fields = ImmutableList.builder(); + ImmutableList.Builder typeFields = ImmutableList.builder(); + for (int i = 0; i < row.getFields().size(); i++) { + io.trino.sql.tree.Row.Field field = row.getFields().get(i); + io.trino.sql.ir.Expression expression = coerceIfNecessary(analysis, field.getExpression(), translateExpression(field.getExpression())); + fields.add(expression); + typeFields.add(new RowType.Field(field.getName().map(Identifier::getCanonicalValue), expression.type())); + } + return new io.trino.sql.ir.Row(fields.build(), RowType.from(typeFields.build())); } private io.trino.sql.ir.Expression translate(ComparisonExpression expression) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index 1bd409a8aadc..49c0787ace20 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -300,18 +300,20 @@ public Result apply(ValuesNode valuesNode, Captures captures, Context context) boolean anyRewritten = false; ImmutableList.Builder rows = ImmutableList.builder(); - for (Expression row : valuesNode.getRows().get()) { + for (Expression original : valuesNode.getRows().get()) { Expression rewritten; - if (row instanceof Row) { + if (original instanceof Row row) { // preserve the structure of row - rewritten = new Row(((Row) row).items().stream() - .map(item -> rewriter.rewrite(item, context)) - .collect(toImmutableList())); + rewritten = new Row( + row.items().stream() + .map(item -> rewriter.rewrite(item, context)) + .collect(toImmutableList()), + row.type()); } else { - rewritten = rewriter.rewrite(row, context); + rewritten = rewriter.rewrite(original, context); } - if (!row.equals(rewritten)) { + if (!original.equals(rewritten)) { anyRewritten = true; } rows.add(rewritten); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java index 487d2449261b..95c9ff9110ca 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java @@ -20,73 +20,57 @@ import io.trino.sql.ir.Expression; import io.trino.sql.ir.ExpressionTreeRewriter; import io.trino.sql.ir.Row; -import io.trino.type.UnknownType; /** * Transforms expressions of the form * *
  *  CAST(
- *      CAST(
- *          ROW(x, y)
- *          AS row(f1 type1, f2 type2))
- *      AS row(g1 type3, g2 type4))
+ *      ROW(x, y)
+ *      AS row(f1 type1, f2 type2))
  * 
* * to * *
- *  CAST(
- *      ROW(
- *          CAST(x AS type1),
- *          CAST(y AS type2))
- *      AS row(g1 type3, g2 type4))
+ *  ROW(
+ *      CAST(x AS type1) as f1,
+ *      CAST(y AS type2) as f2)
  * 
- * - * Note: it preserves the top-level CAST if the row type has field names because the names are needed by the ROW to JSON cast - * TODO: ideally, the types involved in ROW to JSON cast should be captured at analysis time and - * remain fixed for the duration of the optimization process so as to have flexibility in terms - * of removing field names, which are irrelevant in the IR */ public class PushCastIntoRow extends ExpressionRewriteRuleSet { public PushCastIntoRow() { - super((expression, context) -> ExpressionTreeRewriter.rewriteWith(new Rewriter(), expression, false)); + super((expression, context) -> ExpressionTreeRewriter.rewriteWith(new Rewriter(), expression, null)); } private static class Rewriter - extends io.trino.sql.ir.ExpressionRewriter + extends io.trino.sql.ir.ExpressionRewriter { @Override - public Expression rewriteCast(Cast node, Boolean inRowCast, ExpressionTreeRewriter treeRewriter) + public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter treeRewriter) { - if (!(node.type() instanceof RowType type)) { - return treeRewriter.defaultRewrite(node, false); + if (!(node.type() instanceof RowType castToType)) { + return treeRewriter.defaultRewrite(node, null); } - // if inRowCast == true or row is anonymous, we're free to push Cast into Row. An enclosing CAST(... AS ROW) will take care of preserving field names - // otherwise, apply recursively with inRowCast == true and don't push this one - - if (inRowCast || type.getFields().stream().allMatch(field -> field.getName().isEmpty())) { - Expression value = treeRewriter.rewrite(node.expression(), true); - - if (value instanceof Row row) { - ImmutableList.Builder items = ImmutableList.builder(); - for (int i = 0; i < row.items().size(); i++) { - Expression item = row.items().get(i); - Type itemType = type.getFields().get(i).getType(); - if (!(itemType instanceof UnknownType)) { - item = new Cast(item, itemType); - } - items.add(item); + Expression value = treeRewriter.rewrite(node.expression(), null); + if (value instanceof Row(java.util.List expressions, RowType type)) { + ImmutableList.Builder items = ImmutableList.builder(); + for (int i = 0; i < expressions.size(); i++) { + Expression fieldValue = expressions.get(i); + Type fieldType = castToType.getFields().get(i).getType(); + if (!fieldValue.type().equals(fieldType)) { + fieldValue = new Cast(fieldValue, fieldType); } - return new Row(items.build()); + items.add(fieldValue); } + return new Row(items.build(), castToType); } - return treeRewriter.defaultRewrite(node, true); + return treeRewriter.defaultRewrite(node, null); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java index a1144fa87f2a..5489ea16ccdc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java @@ -93,7 +93,7 @@ public Result apply(CorrelatedJoinNode parent, Captures captures, Context contex .putIdentities(parent.getInput().getOutputSymbols()); forEachPair( values.getOutputSymbols().stream(), - row.items().stream(), + row.children().stream(), assignments::put); return Result.ofPlanNode(projectNode(parent.getInput(), assignments.build(), context)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java index 6978e8504df3..d73326394cb7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java @@ -52,7 +52,6 @@ import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Parameter; import io.trino.sql.tree.Query; -import io.trino.sql.tree.Row; import io.trino.sql.tree.SelectItem; import io.trino.sql.tree.ShowStats; import io.trino.sql.tree.Statement; @@ -74,6 +73,7 @@ import static io.trino.spi.type.TimeZoneKey.UTC_KEY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.QueryUtil.aliased; +import static io.trino.sql.QueryUtil.row; import static io.trino.sql.QueryUtil.selectAll; import static io.trino.sql.QueryUtil.selectList; import static io.trino.sql.QueryUtil.simpleQuery; @@ -176,26 +176,24 @@ private Node rewriteShowStats(Plan plan, PlanNodeStatsEstimate planNodeStatsEsti String columnName = root.getColumnNames().get(columnIndex); Type columnType = outputSymbol.type(); SymbolStatsEstimate symbolStatistics = planNodeStatsEstimate.getSymbolStatistics(outputSymbol); - ImmutableList.Builder rowValues = ImmutableList.builder(); - rowValues.add(new StringLiteral(columnName)); - rowValues.add(toDoubleLiteral(symbolStatistics.getAverageRowSize() * planNodeStatsEstimate.getOutputRowCount() * (1 - symbolStatistics.getNullsFraction()))); - rowValues.add(toDoubleLiteral(symbolStatistics.getDistinctValuesCount())); - rowValues.add(toDoubleLiteral(symbolStatistics.getNullsFraction())); - rowValues.add(NULL_DOUBLE); - rowValues.add(toStringLiteral(columnType, symbolStatistics.getLowValue())); - rowValues.add(toStringLiteral(columnType, symbolStatistics.getHighValue())); - rowsBuilder.add(new Row(rowValues.build())); + rowsBuilder.add(row( + new StringLiteral(columnName), + toDoubleLiteral(symbolStatistics.getAverageRowSize() * planNodeStatsEstimate.getOutputRowCount() * (1 - symbolStatistics.getNullsFraction())), + toDoubleLiteral(symbolStatistics.getDistinctValuesCount()), + toDoubleLiteral(symbolStatistics.getNullsFraction()), + NULL_DOUBLE, + toStringLiteral(columnType, symbolStatistics.getLowValue()), + toStringLiteral(columnType, symbolStatistics.getHighValue()))); } // Stats for whole table - ImmutableList.Builder rowValues = ImmutableList.builder(); - rowValues.add(NULL_VARCHAR); - rowValues.add(NULL_DOUBLE); - rowValues.add(NULL_DOUBLE); - rowValues.add(NULL_DOUBLE); - rowValues.add(toDoubleLiteral(planNodeStatsEstimate.getOutputRowCount())); - rowValues.add(NULL_VARCHAR); - rowValues.add(NULL_VARCHAR); - rowsBuilder.add(new Row(rowValues.build())); + rowsBuilder.add(row( + NULL_VARCHAR, + NULL_DOUBLE, + NULL_DOUBLE, + NULL_DOUBLE, + toDoubleLiteral(planNodeStatsEstimate.getOutputRowCount()), + NULL_VARCHAR, + NULL_VARCHAR)); List resultRows = rowsBuilder.build(); return simpleQuery(selectAll(selectItems), aliased(new Values(resultRows), "table_stats", statsColumnNames)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java index a9d36657f5b9..75b74af43994 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java @@ -53,7 +53,7 @@ protected Expression rewriteExpression(Expression node, Void context, Expression public Expression rewriteRow(Row node, Void context, ExpressionTreeRewriter treeRewriter) { // rewrite Row items to preserve Row structure of ValuesNode - return new Row(node.items().stream().map(item -> new Constant(INTEGER, 0L)).collect(toImmutableList())); + return new Row(node.items().stream().map(item -> new Constant(INTEGER, 0L)).collect(toImmutableList()), node.type()); } }, expression)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java index 059fa2d316d2..e895e4ef7087 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.java @@ -50,18 +50,24 @@ public void test() new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT)))); test( new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")))), anonymousRow(BIGINT, VARCHAR)), - new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT), new Cast(new Constant(VARCHAR, Slices.utf8Slice("a")), VARCHAR)))); + new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), BIGINT), new Constant(VARCHAR, Slices.utf8Slice("a"))))); test( new Cast(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")))), anonymousRow(SMALLINT, VARCHAR)), anonymousRow(BIGINT, VARCHAR)), - new Row(ImmutableList.of(new Cast(new Cast(new Constant(INTEGER, 1L), SMALLINT), BIGINT), new Cast(new Cast(new Constant(VARCHAR, Slices.utf8Slice("a")), VARCHAR), VARCHAR)))); + new Row(ImmutableList.of(new Cast(new Cast(new Constant(INTEGER, 1L), SMALLINT), BIGINT), new Constant(VARCHAR, Slices.utf8Slice("a"))))); // named fields in top-level cast preserved test( new Cast(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")))), anonymousRow(SMALLINT, VARCHAR)), rowType(field("x", BIGINT), field(VARCHAR))), - new Cast(new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), SMALLINT), new Cast(new Constant(VARCHAR, Slices.utf8Slice("a")), VARCHAR))), rowType(field("x", BIGINT), field(VARCHAR)))); + new Row(ImmutableList.of(new Cast(new Cast(new Constant(INTEGER, 1L), SMALLINT), BIGINT), new Constant(VARCHAR, Slices.utf8Slice("a"))), rowType(field("x", BIGINT), field(VARCHAR)))); test( new Cast(new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")))), rowType(field("a", SMALLINT), field("b", VARCHAR))), rowType(field("x", BIGINT), field(VARCHAR))), - new Cast(new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), SMALLINT), new Cast(new Constant(VARCHAR, Slices.utf8Slice("a")), VARCHAR))), rowType(field("x", BIGINT), field(VARCHAR)))); + new Row(ImmutableList.of(new Cast(new Cast(new Constant(INTEGER, 1L), SMALLINT), BIGINT), new Constant(VARCHAR, Slices.utf8Slice("a"))), rowType(field("x", BIGINT), field(VARCHAR)))); + test( + new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a")))), rowType(field("x", SMALLINT), field(VARCHAR))), + new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), SMALLINT), new Constant(VARCHAR, Slices.utf8Slice("a"))), rowType(field("x", SMALLINT), field(VARCHAR)))); + test( + new Cast(new Row(ImmutableList.of(new Constant(INTEGER, 1L), new Constant(VARCHAR, Slices.utf8Slice("a"))), rowType(field("a", SMALLINT), field("b", VARCHAR))), rowType(field("x", SMALLINT), field(VARCHAR))), + new Row(ImmutableList.of(new Cast(new Constant(INTEGER, 1L), SMALLINT), new Constant(VARCHAR, Slices.utf8Slice("a"))), rowType(field("x", SMALLINT), field(VARCHAR)))); // expression nested in another unrelated expression test( diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestGroupBy.java b/core/trino-main/src/test/java/io/trino/sql/query/TestGroupBy.java index 3cfe153ecf0e..326c2bc8e2ac 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestGroupBy.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestGroupBy.java @@ -70,7 +70,7 @@ public void testCastDifferentCase() "SELECT CAST(row(x) AS row(\"a\" bigint)) " + "FROM (VALUES 42) t(x) " + "GROUP BY CAST(row(x) AS row(\"A\" bigint))")) - .failure().hasMessage("line 1:8: 'CAST(ROW (x) AS ROW(\"a\" bigint))' must be an aggregate expression or appear in GROUP BY clause"); + .failure().hasMessage("line 1:8: 'CAST(ROW(x) AS ROW(\"a\" bigint))' must be an aggregate expression or appear in GROUP BY clause"); } @Test diff --git a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index c1f283a1f548..4a01b3d31238 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -152,9 +152,21 @@ protected String visitNode(Node node, Void context) @Override protected String visitRow(Row node, Void context) { - return node.getItems().stream() + return node.getFields().stream() .map(child -> process(child, context)) - .collect(joining(", ", "ROW (", ")")); + .collect(joining(", ", "ROW(", ")")); + } + + @Override + protected String visitRowField(Row.Field node, Void context) + { + StringBuilder builder = new StringBuilder(); + builder.append(process(node.getExpression(), context)); + if (node.getName().isPresent()) { + builder.append(" AS "); + builder.append(process(node.getName().get(), context)); + } + return builder.toString(); } @Override diff --git a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java index b31276c42e0e..4f85dcffd689 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java +++ b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java @@ -43,6 +43,7 @@ import io.trino.sql.tree.Values; import io.trino.sql.tree.WindowDefinition; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -141,7 +142,7 @@ public static Values values(Row... row) public static Row row(Expression... values) { - return new Row(ImmutableList.copyOf(values)); + return new Row(Arrays.stream(values).map(Row.Field::new).collect(Collectors.toList())); } public static Relation aliased(Relation relation, String alias) diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 4b860d5eb9c2..87d9465a642c 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -2018,17 +2018,25 @@ protected Void visitRow(Row node, Integer indent) { builder.append("ROW("); boolean firstItem = true; - for (Expression item : node.getItems()) { + for (Row.Field field : node.getFields()) { if (!firstItem) { builder.append(", "); } - process(item, indent); + process(field, indent); firstItem = false; } builder.append(")"); return null; } + @Override + protected Void visitRowField(Row.Field node, Integer context) + { + builder.append(formatExpression(node.getExpression())); + node.getName().ifPresent(name -> builder.append(" AS ").append(formatName(name))); + return null; + } + @Override protected Void visitStartTransaction(StartTransaction node, Integer indent) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 06322c2c9234..8fdbf0f5b789 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -2374,7 +2374,20 @@ public Node visitParenthesizedExpression(SqlBaseParser.ParenthesizedExpressionCo @Override public Node visitRowConstructor(SqlBaseParser.RowConstructorContext context) { - return new Row(getLocation(context), visit(context.expression(), Expression.class)); + if (context.fieldConstructor().isEmpty()) { + return new Row(getLocation(context), visit(context.expression(), Expression.class).stream() + .map(expression -> new Row.Field(expression.getLocation().orElseThrow(), Optional.empty(), expression)) + .toList()); + } + return new Row(getLocation(context), visit(context.fieldConstructor(), Row.Field.class)); + } + + @Override + public Node visitFieldConstructor(SqlBaseParser.FieldConstructorContext context) + { + return new Row.Field(getLocation(context), + visitIfPresent(context.identifier(), Identifier.class), + (Expression) visit(context.expression())); } @Override diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index 981bded97775..9b56bc1497e5 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -497,6 +497,11 @@ protected R visitRow(Row node, C context) return visitExpression(node, context); } + protected R visitRowField(Row.Field node, C context) + { + return visitNode(node, context); + } + protected R visitTableSubquery(TableSubquery node, C context) { return visitQueryBody(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java index 4ad7f39c4b83..01d4dc298ae2 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java @@ -534,9 +534,19 @@ protected Void visitValues(Values node, C context) @Override protected Void visitRow(Row node, C context) { - for (Expression expression : node.getItems()) { - process(expression, context); + for (Row.Field field : node.getFields()) { + process(field, context); + } + return null; + } + + @Override + protected Void visitRowField(Row.Field node, C context) + { + if (node.getName().isPresent()) { + process(node.getName().get(), context); } + process(node.getExpression(), context); return null; } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java index c08a5ad033dc..81e6563a28c9 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java @@ -88,10 +88,14 @@ protected Expression visitRow(Row node, Context context) } } - List items = rewrite(node.getItems(), context); + ImmutableList.Builder builder = ImmutableList.builder(); + for (Row.Field field : node.getFields()) { + builder.add(new Row.Field(field.getLocation().orElseThrow(), field.getName(), rewrite(field.getExpression(), context.get()))); + } + List fields = builder.build(); - if (!sameElements(node.getItems(), items)) { - return new Row(node.getLocation().orElseThrow(), items); + if (!sameElements(node.getFields(), fields)) { + return new Row(node.getLocation().orElseThrow(), fields); } return node; diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Row.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Row.java index cf49ca6ad578..e8ff062ccda2 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Row.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Row.java @@ -19,27 +19,29 @@ import java.util.Objects; import java.util.Optional; +import static java.util.Objects.requireNonNull; + public final class Row extends Expression { - private final List items; + private final List fields; @Deprecated - public Row(List items) + public Row(List fields) { super(Optional.empty()); - this.items = ImmutableList.copyOf(items); + this.fields = ImmutableList.copyOf(fields); } - public Row(NodeLocation location, List items) + public Row(NodeLocation location, List fields) { super(location); - this.items = ImmutableList.copyOf(items); + this.fields = ImmutableList.copyOf(fields); } - public List getItems() + public List getFields() { - return items; + return fields; } @Override @@ -51,13 +53,13 @@ public R accept(AstVisitor visitor, C context) @Override public List getChildren() { - return items; + return fields; } @Override public int hashCode() { - return Objects.hash(items); + return Objects.hash(fields); } @Override @@ -70,7 +72,7 @@ public boolean equals(Object obj) return false; } Row other = (Row) obj; - return Objects.equals(this.items, other.items); + return Objects.equals(this.fields, other.fields); } @Override @@ -78,4 +80,93 @@ public boolean shallowEquals(Node other) { return sameClass(this, other); } + + public static class Field + extends Node + { + private final Optional name; + private final Expression expression; + + public Field(NodeLocation location, Optional name, Expression expression) + { + super(location); + + this.name = requireNonNull(name, "name is null"); + this.expression = requireNonNull(expression, "expression is null"); + } + + @Deprecated + public Field(Expression expression) + { + super(Optional.empty()); + + this.name = Optional.empty(); + this.expression = requireNonNull(expression, "expression is null"); + } + + public Optional getName() + { + return name; + } + + public Expression getExpression() + { + return expression; + } + + @Override + public List getChildren() + { + ImmutableList.Builder children = ImmutableList.builder(); + name.ifPresent(children::add); + children.add(expression); + + return children.build(); + } + + @Override + protected R accept(AstVisitor visitor, C context) + { + return visitor.visitRowField(this, context); + } + + @Override + public String toString() + { + StringBuilder builder = new StringBuilder(); + if (name.isPresent()) { + builder.append(name.get()); + builder.append(" "); + } + builder.append(expression); + + return builder.toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Row.Field field = (Row.Field) o; + return name.equals(field.name) && + expression.equals(field.expression); + } + + @Override + public int hashCode() + { + return Objects.hash(name, expression); + } + + @Override + public boolean shallowEquals(Node other) + { + return sameClass(this, other); + } + } } diff --git a/core/trino-parser/src/test/java/io/trino/sql/TestExpressionFormatter.java b/core/trino-parser/src/test/java/io/trino/sql/TestExpressionFormatter.java index be452f97adb7..8518dcc12623 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/TestExpressionFormatter.java +++ b/core/trino-parser/src/test/java/io/trino/sql/TestExpressionFormatter.java @@ -13,10 +13,13 @@ */ package io.trino.sql; +import com.google.common.collect.ImmutableList; import io.trino.sql.tree.Expression; import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.IntervalLiteral; +import io.trino.sql.tree.NodeLocation; +import io.trino.sql.tree.Row; import io.trino.sql.tree.StringLiteral; import org.junit.jupiter.api.Test; @@ -110,6 +113,40 @@ public void testIntervalLiteral() "INTERVAL -'2' HOUR TO SECOND"); } + @Test + public void testRowLiteral() + { + assertFormattedExpression( + createRow("a"), + "ROW('v0' AS a)"); + assertFormattedExpression( + createRow((String) null), + "ROW('v0')"); + assertFormattedExpression( + createRow("a", null, "b"), + "ROW('v0' AS a, 'v1', 'v2' AS b)"); + assertFormattedExpression( + new Row(ImmutableList.of(createRowField("x", new Row(ImmutableList.of(createRowField("y", createRow("a", null, "b"))))))), + "ROW(ROW(ROW('v0' AS a, 'v1', 'v2' AS b) AS y) AS x)"); + assertFormattedExpression( + new Row(ImmutableList.of(createRowField(null, new Row(ImmutableList.of(createRowField(null, createRow("a", null, "b"))))))), + "ROW(ROW(ROW('v0' AS a, 'v1', 'v2' AS b)))"); + } + + private static Row createRow(String... fieldNames) + { + ImmutableList.Builder fields = ImmutableList.builder(); + for (int i = 0; i < fieldNames.length; i++) { + fields.add(createRowField(fieldNames[i], new StringLiteral("v" + i))); + } + return new Row(fields.build()); + } + + private static Row.Field createRowField(String fieldName, Expression expression) + { + return new Row.Field(new NodeLocation(1, 1), Optional.ofNullable(fieldName).map(Identifier::new), expression); + } + private void assertFormattedExpression(Expression expression, String expected) { assertThat(ExpressionFormatter.formatExpression(expression)).as("formatted expression") diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index b3c35ef28ab0..7bc66ca5fdc8 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -605,9 +605,9 @@ public void testRowSubscript() new Row( location(1, 1), ImmutableList.of( - new LongLiteral(location(1, 6), "1"), - new StringLiteral(location(1, 9), "a"), - new BooleanLiteral(location(1, 14), "true"))), + new Row.Field(location(1, 6), Optional.empty(), new LongLiteral(location(1, 6), "1")), + new Row.Field(location(1, 9), Optional.empty(), new StringLiteral(location(1, 9), "a")), + new Row.Field(location(1, 14), Optional.empty(), new BooleanLiteral(location(1, 14), "true")))), new LongLiteral(location(1, 20), "1"))); } @@ -640,7 +640,7 @@ public void testAllColumns() ImmutableList.of( new AllColumns( Optional.empty(), - Optional.of(new Row(ImmutableList.of(new LongLiteral("1"), new StringLiteral("a"), new BooleanLiteral("true")))), + Optional.of(row(new LongLiteral("1"), new StringLiteral("a"), new BooleanLiteral("true"))), ImmutableList.of()))))); assertStatement("SELECT ROW (1, 'a', true).* AS (f1, f2, f3)", simpleQuery( @@ -649,7 +649,7 @@ public void testAllColumns() ImmutableList.of( new AllColumns( Optional.empty(), - Optional.of(new Row(ImmutableList.of(new LongLiteral("1"), new StringLiteral("a"), new BooleanLiteral("true")))), + Optional.of(row(new LongLiteral("1"), new StringLiteral("a"), new BooleanLiteral("true"))), ImmutableList.of(new Identifier("f1"), new Identifier("f2"), new Identifier("f3"))))))); } @@ -1627,7 +1627,7 @@ public void testSelectWithRowType() selectList( new DereferenceExpression( new Cast( - new Row(Lists.newArrayList(new LongLiteral("11"), new LongLiteral("12"))), + row(new LongLiteral("11"), new LongLiteral("12")), rowType(location(1, 26), field(location(1, 30), "COL0", simpleType(location(1, 35), "INTEGER")), field(location(1, 44), "COL1", simpleType(location(1, 49), "INTEGER")))), @@ -4952,8 +4952,8 @@ public void testQuantifiedComparison() ImmutableList.of(), Optional.empty(), new Values(location(1, 13), ImmutableList.of( - new Row(location(1, 20), ImmutableList.of(new LongLiteral(location(1, 24), "1"))), - new Row(location(1, 28), ImmutableList.of(new LongLiteral(location(1, 32), "2"))))), + new Row(location(1, 20), ImmutableList.of(new Row.Field(location(1, 24), Optional.empty(), new LongLiteral(location(1, 24), "1")))), + new Row(location(1, 28), ImmutableList.of(new Row.Field(location(1, 32), Optional.empty(), new LongLiteral(location(1, 32), "2")))))), Optional.empty(), Optional.empty(), Optional.empty())))); diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index 41940b9368d5..23c0e4b75550 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -159,8 +159,8 @@ private static Stream statements() "line 1:52: mismatched input ''. Expecting: "), Arguments.of("SELECT * FROM t FOR VERSION AS OF TIMESTAMP WHERE", "line 1:50: mismatched input ''. Expecting: "), - Arguments.of("SELECT ROW(DATE '2022-10-10', DOUBLE 12.0)", - "line 1:38: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'ORDER', 'OVER', 'PRECISION', '[', '||', , "), + Arguments.of("SELECT (DATE '2022-10-10', DOUBLE 12.0)", + "line 1:35: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'OVER', 'PRECISION', '[', '||', , "), Arguments.of("VALUES(DATE 2)", "line 1:13: mismatched input '2'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'OVER', '[', '||', , "), Arguments.of("SELECT count(DISTINCT *) FROM (VALUES 1)", diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 1f7378f58d05..1662c3e5e97e 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -1772,8 +1772,8 @@ public void testShowCreateView() "FROM\n" + " (\n" + " VALUES \n" + - " ROW (1, 'one')\n" + - " , ROW (2, 't')\n" + + " ROW(1, 'one')\n" + + " , ROW(2, 't')\n" + ") t (col1, col2)", getSession().getCatalog().get(), getSession().getSchema().get(),