Skip to content

Commit

Permalink
Allow field name declaration in row literal
Browse files Browse the repository at this point in the history
Add support for `row(a 1, b 2)` instead of the much more complex
`cast(row(1, 2) as row(a integer, b integer))`.
  • Loading branch information
dain committed Mar 10, 2025
1 parent 0b565ee commit 6cbfff3
Show file tree
Hide file tree
Showing 71 changed files with 491 additions and 274 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ primaryExpression
| QUESTION_MARK #parameter
| POSITION '(' valueExpression IN valueExpression ')' #position
| '(' expression (',' expression)+ ')' #rowConstructor
| ROW '(' expression (',' expression)* ')' #rowConstructor
| ROW '(' fieldConstructor (',' fieldConstructor)* ')' #rowConstructor
| name=LISTAGG '(' setQuantifier? expression (',' string)?
(ON OVERFLOW listAggOverflowBehavior)? ')'
(WITHIN GROUP '(' ORDER BY sortItem (',' sortItem)* ')')
Expand Down Expand Up @@ -646,6 +646,13 @@ primaryExpression
')' #jsonArray
;

// Match a single expression before identifier expression so that the field `ROW(1, 2)` is not parsed as
// "ROW" identifier followed by the expression `(1, 2)`.
fieldConstructor
: expression
| identifier expression
;

jsonPathInvocation
: jsonValueExpression ',' path=string
(AS pathName=identifier)?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private List<Object> getSymbolValues(ValuesNode valuesNode, int symbolId, Type r
checkState(valuesNode.getRows().isPresent(), "rows is empty");
return valuesNode.getRows().get().stream()
.map(row -> switch (row) {
case Row value -> ((Constant) value.items().get(symbolId)).value();
case Row value -> ((Constant) value.fields().get(symbolId).value()).value();
case Constant(Type type, SqlRow value) -> readNativeValue(symbolType, value.getRawFieldBlock(symbolId), value.getRawIndex());
default -> throw new IllegalArgumentException("Expected Row or Constant: " + row);
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,10 +701,16 @@ protected Boolean visitTryExpression(TryExpression node, Void context)
@Override
protected Boolean visitRow(Row node, Void context)
{
return node.getItems().stream()
return node.getFields().stream()
.allMatch(item -> process(item, context));
}

@Override
protected Boolean visitRowField(Row.Field node, Void context)
{
return process(node.getExpression(), context);
}

@Override
protected Boolean visitParameter(Parameter node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,11 @@ public Type process(Node node, @Nullable Context context)
@Override
protected Type visitRow(Row node, Context context)
{
List<Type> types = node.getItems().stream()
.map(child -> process(child, context))
List<RowType.Field> fields = node.getFields().stream()
.map(field -> new RowType.Field(field.getName().map(Identifier::getCanonicalValue), process(field.getExpression(), context)))
.collect(toImmutableList());

Type type = RowType.anonymous(types);
Type type = RowType.from(fields);
return setExpressionType(node, type);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3929,7 +3929,7 @@ protected Scope visitValues(Values node, Optional<Scope> scope)
// TODO coerce the whole Row and add an Optimizer rule that converts CAST(ROW(...) AS ...) into ROW(CAST(...), CAST(...), ...).
// The rule would also handle Row-type expressions that were specified as CAST(ROW). It should support multiple casts over a ROW.
for (int i = 0; i < actualType.getTypeParameters().size(); i++) {
Expression item = ((Row) row).getItems().get(i);
Expression item = ((Row) row).getFields().get(i).getExpression();
Type actualItemType = actualType.getTypeParameters().get(i);
Type expectedItemType = commonSuperType.getTypeParameters().get(i);
if (!actualItemType.equals(expectedItemType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ protected Void visitLogical(Logical node, C context)
@Override
protected Void visitRow(Row node, C context)
{
for (Expression expression : node.items()) {
process(expression, context);
for (Row.Field field : node.fields()) {
process(field.value(), context);
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ protected String visitArray(Array node, Void context)
@Override
protected String visitRow(Row node, Void context)
{
return node.items().stream()
.map(child -> process(child, context))
return node.fields().stream()
.map(this::formatRowField)
.collect(joining(", ", "ROW (", ")"));
}

private String formatRowField(Row.Field field)
{
return field.name().map(name -> name + " ").orElse("") + process(field.value(), null);
}

@Override
protected String visitExpression(Expression node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,13 @@ protected Expression visitRow(Row node, Context<C> context)
}
}

List<Expression> items = rewrite(node.items(), context);
ImmutableList.Builder<Row.Field> builder = ImmutableList.builder();
for (Row.Field field : node.fields()) {
builder.add(new Row.Field(field.name(), rewrite(field.value(), context.get())));
}
List<Row.Field> items = builder.build();

if (!sameElements(node.items(), items)) {
if (!sameElements(node.fields(), items)) {
return new Row(items);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public static boolean mayFail(PlannerContext plannerContext, Expression expressi
case Logical e -> e.terms().stream().anyMatch(argument -> mayFail(plannerContext, argument));
case NullIf e -> mayFail(plannerContext, e.first()) || mayFail(plannerContext, e.second());
case Reference e -> false;
case Row e -> e.items().stream().anyMatch(argument -> mayFail(plannerContext, argument));
case Row e -> e.fields().stream().anyMatch(field -> mayFail(plannerContext, field.value()));
case Switch e -> mayFail(plannerContext, e.operand()) || e.whenClauses().stream().anyMatch(clause -> mayFail(plannerContext, clause.getOperand()) || mayFail(plannerContext, clause.getResult())) ||
mayFail(plannerContext, e.defaultValue());
};
Expand Down
50 changes: 43 additions & 7 deletions core/trino-main/src/main/java/io/trino/sql/ir/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,32 @@
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;

@JsonSerialize
public record Row(List<Expression> items)
public record Row(List<Field> fields)
implements Expression
{
public Row
{
requireNonNull(items, "items is null");
items = ImmutableList.copyOf(items);
requireNonNull(fields, "fields is null");
fields = ImmutableList.copyOf(fields);
}

public static Row anonymousRow(List<Expression> values)
{
return new Row(values.stream()
.map(Field::anonymousField)
.collect(Collectors.toList()));
}

@Override
public Type type()
{
return RowType.anonymous(items.stream().map(Expression::type).collect(Collectors.toList()));
return RowType.from(fields.stream().map(Field::asRowTypeField).collect(Collectors.toList()));
}

@Override
Expand All @@ -48,16 +56,44 @@ public <R, C> R accept(IrVisitor<R, C> visitor, C context)
@Override
public List<? extends Expression> children()
{
return items;
return fields.stream()
.map(Field::value)
.collect(Collectors.toList());
}

@Override
public String toString()
{
return "(" +
items.stream()
.map(Expression::toString)
fields.stream()
.map(Field::toString)
.collect(Collectors.joining(", ")) +
")";
}

@JsonSerialize
public record Field(Optional<String> name, Expression value)
{
public Field
{
requireNonNull(name, "name is null");
requireNonNull(value, "value is null");
}

public static Field anonymousField(Expression value)
{
return new Field(Optional.empty(), value);
}

public RowType.Field asRowTypeField()
{
return new RowType.Field(name, value.type());
}

@Override
public String toString()
{
return name.map(n -> n + " " + value).orElseGet(value::toString);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ private Object evaluateInternal(Switch expression, Session session, Map<String,
private Object evaluateInternal(Row expression, Session session, Map<String, Object> bindings)
{
return buildRowValue((RowType) expression.type(), builders -> {
for (int i = 0; i < expression.items().size(); ++i) {
for (int i = 0; i < expression.fields().size(); ++i) {
Expression fieldValue = expression.fields().get(i).value();
writeNativeValue(
expression.items().get(i).type(), builders.get(i),
evaluate(expression.items().get(i), session, bindings));
fieldValue.type(), builders.get(i),
evaluate(fieldValue, session, bindings));
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,17 @@ private Optional<Expression> processChildren(Expression expression, Session sess
case Logical logical -> process(logical.terms(), session, bindings).map(arguments -> new Logical(logical.operator(), arguments));
case Call call -> process(call.arguments(), session, bindings).map(arguments -> new Call(call.function(), arguments));
case Array array -> process(array.elements(), session, bindings).map(elements -> new Array(array.elementType(), elements));
case Row row -> process(row.items(), session, bindings).map(fields -> new Row(fields));
case Row row -> {
boolean changed = false;
ImmutableList.Builder<Row.Field> fields = ImmutableList.builder();
for (Row.Field field : row.fields()) {
Optional<Expression> optimized = process(field.value(), session, bindings);
changed = changed || optimized.isPresent();
fields.add(new Row.Field(field.name(), optimized.orElse(field.value())));
}

yield changed ? Optional.of(new Row(fields.build())) : Optional.empty();
}
case Between between -> {
Optional<Expression> value = process(between.value(), session, bindings);
Optional<Expression> min = process(between.min(), session, bindings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class EvaluateFieldReference
public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings)
{
return switch (expression) {
case FieldReference(Row row, int field) -> Optional.of(row.items().get(field));
case FieldReference(Row row, int field) -> Optional.of(row.fields().get(field).value());
case FieldReference(Constant(RowType type, SqlRow row), int field) -> {
Type fieldType = type.getFields().get(field).getType();
yield Optional.of(new Constant(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class EvaluateRow
@Override
public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings)
{
if (!(expression instanceof Row(List<Expression> fields)) || !fields.stream().allMatch(Constant.class::isInstance)) {
if (!(expression instanceof Row(List<Row.Field> fields)) || !fields.stream().map(Row.Field::value).allMatch(Constant.class::isInstance)) {
return Optional.empty();
}

Expand All @@ -46,7 +46,8 @@ public Optional<Expression> apply(Expression expression, Session session, Map<Sy
rowType,
buildRowValue(rowType, builders -> {
for (int i = 0; i < fields.size(); ++i) {
writeNativeValue(fields.get(i).type(), builders.get(i), ((Constant) fields.get(i)).value());
Expression fieldValue = fields.get(i).value();
writeNativeValue(fieldValue.type(), builders.get(i), ((Constant) fieldValue).value());
}
})));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ public Expression visitValues(ValuesNode node, Void context)
for (Expression expression : node.getRows().get()) {
if (expression instanceof Row row) {
for (int i = 0; i < node.getOutputSymbols().size(); i++) {
Expression value = row.items().get(i);
Expression value = row.fields().get(i).value();
if (!DeterminismEvaluator.isDeterministic(value)) {
nonDeterministic[i] = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ public PlanNode planStatement(Analysis analysis, Statement statement)
if ((statement instanceof CreateTableAsSelect && analysis.getCreate().orElseThrow().isCreateTableAsSelectNoOp()) ||
statement instanceof RefreshMaterializedView && analysis.isSkipMaterializedViewRefresh()) {
Symbol symbol = symbolAllocator.newSymbol("rows", BIGINT);
PlanNode source = new ValuesNode(idAllocator.getNextId(), ImmutableList.of(symbol), ImmutableList.of(new Row(ImmutableList.of(new Constant(BIGINT, 0L)))));
PlanNode source = new ValuesNode(idAllocator.getNextId(), ImmutableList.of(symbol), ImmutableList.of(Row.anonymousRow(ImmutableList.of(new Constant(BIGINT, 0L)))));
return new OutputNode(idAllocator.getNextId(), source, ImmutableList.of("rows"), ImmutableList.of(symbol));
}
return createOutputPlan(planStatementWithoutOutput(analysis, statement), analysis);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ public PlanNode plan(Update node)
rowBuilder.add(new Constant(INTEGER, 0L));

// Finally, the merge row is complete
Expression mergeRow = new Row(rowBuilder.build());
Expression mergeRow = Row.anonymousRow(rowBuilder.build());

List<io.trino.sql.tree.Expression> constraints = analysis.getCheckConstraints(table);
if (!constraints.isEmpty()) {
Expand Down Expand Up @@ -850,7 +850,7 @@ public MergeWriterNode plan(Merge merge)
coerceIfNecessary(analysis, casePredicate.get(), subPlan.rewrite(casePredicate.get())));
}

whenClauses.add(new WhenClause(condition, new Row(rowBuilder.build())));
whenClauses.add(new WhenClause(condition, Row.anonymousRow(rowBuilder.build())));

List<io.trino.sql.tree.Expression> constraints = analysis.getCheckConstraints(mergeAnalysis.getTargetTable());
if (!constraints.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1765,15 +1765,17 @@ protected RelationPlan visitValues(Values node, Void context)
ImmutableList.Builder<Expression> rows = ImmutableList.builder();
for (io.trino.sql.tree.Expression row : node.getRows()) {
if (row instanceof io.trino.sql.tree.Row) {
rows.add(new Row(((io.trino.sql.tree.Row) row).getItems().stream()
.map(item -> coerceIfNecessary(analysis, item, translationMap.rewrite(item)))
rows.add(new Row(((io.trino.sql.tree.Row) row).getFields().stream()
.map(field -> new Row.Field(
field.getName().map(Identifier::getCanonicalValue),
coerceIfNecessary(analysis, field.getExpression(), translationMap.rewrite(field.getExpression()))))
.collect(toImmutableList())));
}
else if (analysis.getType(row) instanceof RowType) {
rows.add(coerceIfNecessary(analysis, row, translationMap.rewrite(row)));
}
else {
rows.add(new Row(ImmutableList.of(coerceIfNecessary(analysis, row, translationMap.rewrite(row)))));
rows.add(Row.anonymousRow(ImmutableList.of(coerceIfNecessary(analysis, row, translationMap.rewrite(row)))));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ private PlanBuilder planScalarSubquery(PlanBuilder subPlan, Cluster<SubqueryExpr
}
}

Expression expression = new Cast(new Row(fields.build()), type);
Expression expression = new Cast(Row.anonymousRow(fields.build()), type);

root = new ProjectNode(idAllocator.getNextId(), root, Assignments.of(column, expression));
}
Expand Down Expand Up @@ -442,7 +442,7 @@ private PlanAndMappings planValue(PlanBuilder subPlan, io.trino.sql.tree.Express

Assignments assignments = Assignments.builder()
.putIdentities(subPlan.getRoot().getOutputSymbols())
.put(wrapped, new Row(ImmutableList.of(column.toSymbolReference())))
.put(wrapped, Row.anonymousRow(ImmutableList.of(column.toSymbolReference())))
.build();

subPlan = subPlan.withNewRoot(new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments));
Expand Down Expand Up @@ -481,7 +481,7 @@ private PlanAndMappings planSubquery(io.trino.sql.tree.Expression subquery, Opti
new ProjectNode(
idAllocator.getNextId(),
relationPlan.getRoot(),
Assignments.of(column, new Cast(new Row(fields.build()), type))));
Assignments.of(column, new Cast(Row.anonymousRow(fields.build()), type))));

return coerceIfNecessary(subqueryPlan, column, subquery, coercion);
}
Expand Down
Loading

0 comments on commit 6cbfff3

Please sign in to comment.