Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix table function execution without partitioning #21378

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public static class TableFunctionOperatorFactory
private final boolean pruneWhenEmpty;

// partitioning channels from all sources
private final List<Integer> partitionChannels;
private final Optional<List<Integer>> partitionChannels;

// subset of partition channels that are already grouped
private final List<Integer> prePartitionedChannels;
Expand Down Expand Up @@ -117,7 +117,7 @@ public TableFunctionOperatorFactory(
Optional<Map<Integer, Integer>> markerChannels,
List<PassThroughColumnSpecification> passThroughSpecifications,
boolean pruneWhenEmpty,
List<Integer> partitionChannels,
Optional<List<Integer>> partitionChannels,
List<Integer> prePartitionedChannels,
List<Integer> sortChannels,
List<SortOrder> sortOrders,
Expand All @@ -134,12 +134,20 @@ public TableFunctionOperatorFactory(
requireNonNull(passThroughSpecifications, "passThroughSpecifications is null");
requireNonNull(partitionChannels, "partitionChannels is null");
requireNonNull(prePartitionedChannels, "prePartitionedChannels is null");
checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
requireNonNull(sortChannels, "sortChannels is null");
requireNonNull(sortOrders, "sortOrders is null");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
if (partitionChannels.isPresent()) {
checkArgument(partitionChannels.get().containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels.get())), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
}
else {
checkArgument(prePartitionedChannels.isEmpty(), "prePartitionedChannels must be empty when partitionChannels is absent");
checkArgument(sortChannels.isEmpty(), "sortChannels must be empty when partitionChannels is absent");
checkArgument(sortOrders.isEmpty(), "sortOrders must be empty when partitionChannels is absent");
checkArgument(preSortedPrefix == 0, "preSortedPrefix must be zero when partitionChannels is absent");
}
requireNonNull(sourceTypes, "sourceTypes is null");
requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");

Expand All @@ -156,7 +164,7 @@ public TableFunctionOperatorFactory(
this.markerChannels = markerChannels.map(ImmutableMap::copyOf);
this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications);
this.pruneWhenEmpty = pruneWhenEmpty;
this.partitionChannels = ImmutableList.copyOf(partitionChannels);
this.partitionChannels = partitionChannels.map(ImmutableList::copyOf);
this.prePartitionedChannels = ImmutableList.copyOf(prePartitionedChannels);
this.sortChannels = ImmutableList.copyOf(sortChannels);
this.sortOrders = ImmutableList.copyOf(sortOrders);
Expand Down Expand Up @@ -242,7 +250,7 @@ public TableFunctionOperator(
Optional<Map<Integer, Integer>> markerChannels,
List<PassThroughColumnSpecification> passThroughSpecifications,
boolean pruneWhenEmpty,
List<Integer> partitionChannels,
Optional<List<Integer>> partitionChannels,
List<Integer> prePartitionedChannels,
List<Integer> sortChannels,
List<SortOrder> sortOrders,
Expand All @@ -260,12 +268,20 @@ public TableFunctionOperator(
requireNonNull(passThroughSpecifications, "passThroughSpecifications is null");
requireNonNull(partitionChannels, "partitionChannels is null");
requireNonNull(prePartitionedChannels, "prePartitionedChannels is null");
checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
requireNonNull(sortChannels, "sortChannels is null");
requireNonNull(sortOrders, "sortOrders is null");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
if (partitionChannels.isPresent()) {
checkArgument(partitionChannels.get().containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels.get())), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
}
else {
checkArgument(prePartitionedChannels.isEmpty(), "prePartitionedChannels must be empty when partitionChannels is absent");
checkArgument(sortChannels.isEmpty(), "sortChannels must be empty when partitionChannels is absent");
checkArgument(sortOrders.isEmpty(), "sortOrders must be empty when partitionChannels is absent");
checkArgument(preSortedPrefix == 0, "preSortedPrefix must be zero when partitionChannels is absent");
}
requireNonNull(sourceTypes, "sourceTypes is null");
requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");

Expand All @@ -275,23 +291,42 @@ public TableFunctionOperator(
this.processEmptyInput = !pruneWhenEmpty;

PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions);
HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels, prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix);

this.outputPages = pageBuffer.pages()
.transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput))
.flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions(
groupPagesIndex,
hashStrategies,
tableFunctionProvider,
session,
functionHandle,
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
markerChannels,
passThroughSpecifications,
processEmptyInput))
.flatMap(TableFunctionPartition::toOutputPages);
if (partitionChannels.isEmpty()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to check if the List is empty as well? I tested locally on some custom table functions by cherry-picking this commit and ran into this case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that an empty list is legitimate, and requires global aggregation (one partition).

this.outputPages = pageBuffer.pages()
.map(page -> {
pagesIndex.clear();
pagesIndex.addPage(page);
return new RegularTableFunctionPartition(
pagesIndex,
0,
pagesIndex.getPositionCount(),
tableFunctionProvider.getDataProcessor(session, functionHandle),
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
markerChannels,
passThroughSpecifications);
})
.flatMap(TableFunctionPartition::toOutputPages);
}
else {
HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels.get(), prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix);
this.outputPages = pageBuffer.pages()
.transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput))
.flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions(
groupPagesIndex,
hashStrategies,
tableFunctionProvider,
session,
functionHandle,
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
markerChannels,
passThroughSpecifications,
processEmptyInput))
.flatMap(TableFunctionPartition::toOutputPages);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1660,10 +1660,9 @@ public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode
}
}

List<Integer> partitionChannels = node.getSpecification()
Optional<List<Integer>> partitionChannels = node.getSpecification()
.map(DataOrganizationSpecification::partitionBy)
.map(list -> getChannelsForSymbols(list, source.getLayout()))
.orElse(ImmutableList.of());
.map(list -> getChannelsForSymbols(list, source.getLayout()));

List<Integer> sortChannels = ImmutableList.of();
List<SortOrder> sortOrders = ImmutableList.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,20 +514,15 @@ public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode
PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(), childRequirements, childRequirements);

List<LocalProperty<Symbol>> desiredProperties = new ArrayList<>();
if (!partitionBy.isEmpty()) {
desiredProperties.add(new GroupingProperty<>(partitionBy));
}
desiredProperties.add(new GroupingProperty<>(partitionBy));
node.getSpecification().flatMap(DataOrganizationSpecification::orderingScheme).ifPresent(orderingScheme -> desiredProperties.addAll(orderingScheme.toLocalProperties()));
Iterator<Optional<LocalProperty<Symbol>>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator();

Set<Symbol> prePartitionedInputs = ImmutableSet.of();
if (!partitionBy.isEmpty()) {
Optional<LocalProperty<Symbol>> groupingRequirement = matchIterator.next();
Set<Symbol> unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of());
prePartitionedInputs = partitionBy.stream()
.filter(symbol -> !unPartitionedInputs.contains(symbol))
.collect(toImmutableSet());
}
Optional<LocalProperty<Symbol>> groupingRequirement = matchIterator.next();
Set<Symbol> unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of());
Set<Symbol> prePartitionedInputs = partitionBy.stream()
.filter(symbol -> !unPartitionedInputs.contains(symbol))
.collect(toImmutableSet());

int preSortedOrderPrefix = 0;
if (prePartitionedInputs.equals(ImmutableSet.copyOf(partitionBy))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.DynamicFilterSourceNode;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
Expand Down Expand Up @@ -359,11 +358,8 @@ public ActualProperties visitTableFunctionProcessor(TableFunctionProcessorNode n
}
}

List<Symbol> partitionBy = node.getSpecification()
.map(DataOrganizationSpecification::partitionBy)
.orElse(ImmutableList.of());
if (!partitionBy.isEmpty()) {
localProperties.add(new GroupingProperty<>(partitionBy));
if (node.getSpecification().isPresent()) {
localProperties.add(new GroupingProperty<>(node.getSpecification().orElseThrow().partitionBy()));
}

// TODO add global single stream property when a Specification is present with no partitioning columns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public record DataOrganizationSpecification(
Expand All @@ -31,4 +32,13 @@ public record DataOrganizationSpecification(
partitionBy = ImmutableList.copyOf(partitionBy);
requireNonNull(orderingScheme, "orderingScheme is null");
}

/**
 * Renders this specification as a labelled string via Guava's
 * {@code toStringHelper}, listing the {@code partitionBy} and
 * {@code orderingScheme} components in that order.
 *
 * @return the string produced by the helper for this instance
 */
@Override
public String toString()
{
    // Build the helper step by step; entries appear in the order they are added.
    var description = toStringHelper(this);
    description.add("partitionBy", partitionBy);
    description.add("orderingScheme", orderingScheme);
    return description.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ public TableFunctionProcessorNode(
.map(OrderingScheme::orderBy)
.map(List::size)
.orElse(0) >= preSorted,
"the number of pre-sorted symbols cannot be greater than the number of all ordering symbols");
"the number of pre-sorted symbols %s cannot be greater than the number of all ordering symbols from specification %s",
preSorted,
specification);
checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned");
this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null");
this.handle = requireNonNull(handle, "handle is null");
Expand Down