Skip to content

Commit

Permalink
Add support for CORRESPONDING option in set operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Mar 9, 2025
1 parent 58ecb8f commit 3bb218f
Show file tree
Hide file tree
Showing 13 changed files with 337 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ rowCount

queryTerm
: queryPrimary #queryTermDefault
| left=queryTerm operator=INTERSECT setQuantifier? right=queryTerm #setOperation
| left=queryTerm operator=(UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation
| left=queryTerm operator=INTERSECT
setQuantifier? CORRESPONDING? right=queryTerm #setOperation
| left=queryTerm operator=(UNION | EXCEPT)
setQuantifier? CORRESPONDING? right=queryTerm #setOperation
;

queryPrimary
Expand Down Expand Up @@ -1000,7 +1002,7 @@ nonReserved
// IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved
: ABSENT | ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT | AUTHORIZATION
| BEGIN | BERNOULLI | BOTH
| CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | COUNT | CURRENT
| CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | CORRESPONDING | COUNT | CURRENT
| DATA | DATE | DAY | DECLARE | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DETERMINISTIC | DISTRIBUTED | DO | DOUBLE
| ELSEIF | EMPTY | ENCODING | ERROR | EXCLUDING | EXECUTE | EXPLAIN
| FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTION | FUNCTIONS
Expand Down Expand Up @@ -1061,6 +1063,7 @@ CONDITIONAL: 'CONDITIONAL';
CONSTRAINT: 'CONSTRAINT';
COUNT: 'COUNT';
COPARTITION: 'COPARTITION';
CORRESPONDING: 'CORRESPONDING';
CREATE: 'CREATE';
CROSS: 'CROSS';
CUBE: 'CUBE';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.toOptional;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -81,6 +82,16 @@ public Field getFieldByIndex(int fieldIndex)
return allFields.get(fieldIndex);
}

/**
* Gets the field at the specified name.
*/
public Optional<Field> getFieldByName(String name)
{
return allFields.stream()
.filter(field -> field.getName().isPresent() && field.getName().get().equals(name))
.collect(toOptional());
}

/**
* Gets only the visible fields.
* No assumptions should be made about the order of the fields returned from this method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3189,11 +3189,29 @@ protected Scope visitSubqueryExpression(SubqueryExpression node, Optional<Scope>
@Override
protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
{
checkState(node.getRelations().size() >= 2);

List<RelationType> childrenTypes = node.getRelations().stream()
.map(relation -> process(relation, scope).getRelationType().withOnlyVisibleFields())
.collect(toImmutableList());
List<Relation> relations = node.getRelations();
checkState(relations.size() == 2, "relation size must be 2");
boolean corresponding = node.isCorresponding();

List<RelationType> childrenTypes = new ArrayList<>();
childrenTypes.add(process(relations.getFirst(), scope).getRelationType().withOnlyVisibleFields());
if (corresponding) {
RelationType left = childrenTypes.getFirst();
RelationType right = process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields();
checkColumnNames(node, left.getVisibleFields());
checkColumnNames(node, right.getVisibleFields());

List<Field> fields = new ArrayList<>();
for (int i = 0; i < left.getAllFieldCount(); i++) {
Field field = left.getFieldByIndex(i);
String name = field.getName().orElseThrow();
fields.add(right.getFieldByName(name).orElseThrow(() -> semanticException(COLUMN_NOT_FOUND, node, "Column '%s' cannot be resolved", name)));
}
childrenTypes.add(new RelationType(fields).withOnlyVisibleFields());
}
else {
childrenTypes.add(process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields());
}

String setOperationName = node.getClass().getSimpleName().toUpperCase(ENGLISH);
Type[] outputFieldTypes = childrenTypes.get(0).getVisibleFields().stream()
Expand Down Expand Up @@ -3264,8 +3282,8 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
.collect(toImmutableSet()));
}

for (int i = 0; i < node.getRelations().size(); i++) {
Relation relation = node.getRelations().get(i);
for (int i = 0; i < relations.size(); i++) {
Relation relation = relations.get(i);
RelationType relationType = childrenTypes.get(i);
for (int j = 0; j < relationType.getVisibleFields().size(); j++) {
Type outputFieldType = outputFieldTypes[j];
Expand All @@ -3279,6 +3297,17 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
return createAndAssignScope(node, scope, outputDescriptorFields);
}

private static void checkColumnNames(SetOperation node, Collection<Field> fields)
{
Set<String> names = new HashSet<>();
for (Field field : fields) {
String name = field.getName().orElseThrow(() -> semanticException(MISSING_COLUMN_NAME, node, "Anonymous columns are not allowed in set operations with CORRESPONDING"));
if (!names.add(name)) {
throw semanticException(AMBIGUOUS_NAME, node, "Duplicate columns found when using CORRESPONDING in set operations: %s", name);
}
}
}

@Override
protected Scope visitJoin(Join node, Optional<Scope> scope)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1865,9 +1865,34 @@ private SetOperationPlan process(SetOperation node)
ImmutableListMultimap.Builder<Symbol, Symbol> symbolMapping = ImmutableListMultimap.builder();
ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();

for (Relation child : node.getRelations()) {
List<Relation> relations = node.getRelations();
checkArgument(relations.size() == 2, "relations size must be 2");
Relation rightRelation = relations.getLast();
for (Relation child : relations) {
RelationPlan plan = process(child, null);

if (node.isCorresponding() && child.equals(rightRelation)) {
// Replace right relation's field order to match the output fields of the set operation
Map<String, Symbol> nameToSymbol = new HashMap<>();
RelationType descriptor = plan.getDescriptor();
Collection<Field> visibleFields = outputFields.getVisibleFields();
for (int i = 0; i < visibleFields.size(); i++) {
nameToSymbol.put(descriptor.getFieldByIndex(i).getName().orElseThrow(), plan.getSymbol(i));
}

ImmutableList.Builder<Symbol> fieldMappingsBuilder = ImmutableList.builderWithExpectedSize(visibleFields.size());
for (Field field : visibleFields) {
String fieldName = field.getName().orElseThrow();
fieldMappingsBuilder.add(nameToSymbol.get(fieldName));
}
List<Symbol> fieldMappings = fieldMappingsBuilder.build();
ProjectNode projectNode = new ProjectNode(
idAllocator.getNextId(),
plan.getRoot(),
Assignments.identity(fieldMappings));
plan = new RelationPlan(projectNode, plan.getScope(), fieldMappings, plan.getOuterContext());
}

NodeAndMappings planAndMappings;
List<Type> types = analysis.getRelationCoercion(child);
if (types == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,116 @@ public void testIntersectWithEmptyBranches()
.describedAs("INTERSECT DISTINCT with empty branches")
.returnsEmptyResult();
}

@Test
void testExceptCorresponding()
{
assertThat(assertions.query(
"""
SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
EXCEPT CORRESPONDING
SELECT * FROM (VALUES ('alice', 1)) t(y, x)
"""))
.returnsEmptyResult();

assertThat(assertions.query(
"""
SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
EXCEPT ALL CORRESPONDING
SELECT * FROM (VALUES ('alice', 1)) t(y, x)
"""))
.matches("VALUES (1, 'alice')");

}

@Test
void testUnionCorresponding()
{
assertThat(assertions.query(
"""
SELECT 1 AS x, 'alice' AS y
UNION CORRESPONDING
SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x)
"""))
.matches("VALUES (1, 'alice'), (2, 'bob')");

assertThat(assertions.query(
"""
SELECT 1 AS x, 'alice' AS y
UNION ALL CORRESPONDING
SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x)
"""))
.matches("VALUES (1, 'alice'), (1, 'alice'), (2, 'bob')");
}

@Test
void testIntersectCorresponding()
{
assertThat(assertions.query(
"""
SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
INTERSECT CORRESPONDING
SELECT * FROM (VALUES ('alice', 1), ('alice', 1)) t(y, x)
"""))
.matches("VALUES (1, 'alice')");

assertThat(assertions.query(
"""
SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
INTERSECT ALL CORRESPONDING
SELECT * FROM (VALUES ('alice', 1), ('alice', 1)) t(y, x)
"""))
.matches("VALUES (1, 'alice'), (1, 'alice')");
}

@Test
void testCorrespondingDuplicateNames()
{
assertThat(assertions.query("SELECT 1 AS x, 2 AS y EXCEPT CORRESPONDING SELECT 1 AS x, 2 AS x"))
.failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x");
assertThat(assertions.query("SELECT 1 AS x, 2 AS x EXCEPT CORRESPONDING SELECT 1 AS y, 2 AS x"))
.failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x");

assertThat(assertions.query("SELECT 1 AS x, 2 AS y UNION CORRESPONDING SELECT 1 AS x, 2 AS x"))
.failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x");
assertThat(assertions.query("SELECT 1 AS x, 2 AS x UNION CORRESPONDING SELECT 1 AS x, 2 AS y"))
.failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x");

assertThat(assertions.query("SELECT 1 AS x, 2 AS y INTERSECT CORRESPONDING SELECT 1 AS x, 2 AS x"))
.failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x");
assertThat(assertions.query("SELECT 1 AS x, 2 AS x INTERSECT CORRESPONDING SELECT 1 AS x, 2 AS y"))
.failure().hasMessage("line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x");
}

@Test
void testCorrespondingNameMismatch()
{
assertThat(assertions.query("SELECT 1 AS x EXCEPT CORRESPONDING SELECT 2 AS y"))
.failure().hasMessage("line 1:15: Column 'x' cannot be resolved");

assertThat(assertions.query("SELECT 1 AS x UNION CORRESPONDING SELECT 2 AS y"))
.failure().hasMessage("line 1:15: Column 'x' cannot be resolved");

assertThat(assertions.query("SELECT 1 AS x INTERSECT CORRESPONDING SELECT 2 AS y"))
.failure().hasMessage("line 1:15: Column 'x' cannot be resolved");
}

@Test
void testCorrespondingWithAnonymousColumn()
{
assertThat(assertions.query("SELECT 1 EXCEPT CORRESPONDING SELECT 2 AS x"))
.failure().hasMessage("line 1:10: Anonymous columns are not allowed in set operations with CORRESPONDING");
assertThat(assertions.query("SELECT 1 AS x EXCEPT CORRESPONDING SELECT 2"))
.failure().hasMessage("line 1:15: Anonymous columns are not allowed in set operations with CORRESPONDING");

assertThat(assertions.query("SELECT 1 UNION CORRESPONDING SELECT 2 AS x"))
.failure().hasMessage("line 1:10: Anonymous columns are not allowed in set operations with CORRESPONDING");
assertThat(assertions.query("SELECT 1 AS x UNION CORRESPONDING SELECT 2"))
.failure().hasMessage("line 1:15: Anonymous columns are not allowed in set operations with CORRESPONDING");

assertThat(assertions.query("SELECT 1 INTERSECT CORRESPONDING SELECT 2 AS x"))
.failure().hasMessage("line 1:10: Anonymous columns are not allowed in set operations with CORRESPONDING");
assertThat(assertions.query("SELECT 1 AS x INTERSECT CORRESPONDING SELECT 2"))
.failure().hasMessage("line 1:15: Anonymous columns are not allowed in set operations with CORRESPONDING");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,9 @@ protected Void visitUnion(Union node, Integer indent)
if (!node.isDistinct()) {
builder.append("ALL ");
}
if (node.isCorresponding()) {
builder.append("CORRESPONDING ");
}
}
}

Expand All @@ -1049,6 +1052,9 @@ protected Void visitExcept(Except node, Integer indent)
if (!node.isDistinct()) {
builder.append("ALL ");
}
if (node.isCorresponding()) {
builder.append("CORRESPONDING ");
}

processRelation(node.getRight(), indent);

Expand All @@ -1068,6 +1074,9 @@ protected Void visitIntersect(Intersect node, Integer indent)
if (!node.isDistinct()) {
builder.append("ALL ");
}
if (node.isCorresponding()) {
builder.append("CORRESPONDING ");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1383,11 +1383,12 @@ public Node visitSetOperation(SqlBaseParser.SetOperationContext context)
QueryBody right = (QueryBody) visit(context.right);

boolean distinct = context.setQuantifier() == null || context.setQuantifier().DISTINCT() != null;
boolean corresponding = context.CORRESPONDING() != null;

return switch (context.operator.getType()) {
case SqlBaseLexer.UNION -> new Union(getLocation(context.UNION()), ImmutableList.of(left, right), distinct);
case SqlBaseLexer.INTERSECT -> new Intersect(getLocation(context.INTERSECT()), ImmutableList.of(left, right), distinct);
case SqlBaseLexer.EXCEPT -> new Except(getLocation(context.EXCEPT()), left, right, distinct);
case SqlBaseLexer.UNION -> new Union(getLocation(context.UNION()), ImmutableList.of(left, right), distinct, corresponding);
case SqlBaseLexer.INTERSECT -> new Intersect(getLocation(context.INTERSECT()), ImmutableList.of(left, right), distinct, corresponding);
case SqlBaseLexer.EXCEPT -> new Except(getLocation(context.EXCEPT()), left, right, distinct, corresponding);
default -> throw new IllegalArgumentException("Unsupported set operation: " + context.operator.getText());
};
}
Expand Down
14 changes: 9 additions & 5 deletions core/trino-parser/src/main/java/io/trino/sql/tree/Except.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ public class Except
private final Relation left;
private final Relation right;

public Except(NodeLocation location, Relation left, Relation right, boolean distinct)
public Except(NodeLocation location, Relation left, Relation right, boolean distinct, boolean corresponding)
{
super(Optional.of(location), distinct);
super(Optional.of(location), distinct, corresponding);
requireNonNull(left, "left is null");
requireNonNull(right, "right is null");

Expand Down Expand Up @@ -73,6 +73,7 @@ public String toString()
.add("left", left)
.add("right", right)
.add("distinct", isDistinct())
.add("corresponding", isCorresponding())
.toString();
}

Expand All @@ -88,13 +89,14 @@ public boolean equals(Object obj)
Except o = (Except) obj;
return Objects.equals(left, o.left) &&
Objects.equals(right, o.right) &&
isDistinct() == o.isDistinct();
isDistinct() == o.isDistinct() &&
isCorresponding() == o.isCorresponding();
}

@Override
public int hashCode()
{
return Objects.hash(left, right, isDistinct());
return Objects.hash(left, right, isDistinct(), isCorresponding());
}

@Override
Expand All @@ -104,6 +106,8 @@ public boolean shallowEquals(Node other)
return false;
}

return this.isDistinct() == ((Except) other).isDistinct();
Except otherExcept = (Except) other;
return this.isDistinct() == otherExcept.isDistinct() &&
this.isCorresponding() == otherExcept.isCorresponding();
}
}
Loading

0 comments on commit 3bb218f

Please sign in to comment.