Commit: address comments
Signed-off-by: Lantao Jin <[email protected]>
LantaoJin committed Feb 6, 2025
1 parent 6224de2 commit 2e954b2
Showing 4 changed files with 60 additions and 29 deletions.
@@ -298,12 +298,12 @@ class FlintSparkPPLLookupITSuite
   test("test LOOKUP lookupTable uid AS id, name") {
     val frame = sql(s"source = $sourceTable| LOOKUP $lookupTable uID AS id, name")
     val expectedResults: Array[Row] = Array(
-      Row(1000, "Jake", "Engineer", "England", 100000, "IT", "Engineer"),
-      Row(1001, "Hello", "Artist", "USA", 70000, null, null),
-      Row(1002, "John", "Doctor", "Canada", 120000, "DATA", "Scientist"),
-      Row(1003, "David", "Doctor", null, 120000, "HR", "Doctor"),
-      Row(1004, "David", null, "Canada", 0, null, null),
-      Row(1005, "Jane", "Scientist", "Canada", 90000, "DATA", "Engineer"))
+      Row(1000, "Jake", "England", 100000, "IT", "Engineer"),
+      Row(1001, "Hello", "USA", 70000, null, null),
+      Row(1002, "John", "Canada", 120000, "DATA", "Scientist"),
+      Row(1003, "David", null, 120000, "HR", "Doctor"),
+      Row(1004, "David", "Canada", 0, null, null),
+      Row(1005, "Jane", "Canada", 90000, "DATA", "Engineer"))
 
     assertSameRows(expectedResults, frame)
   }
@@ -420,12 +420,12 @@ class FlintSparkPPLLookupITSuite
     val frame =
       sql(s"source = $sourceTable | LOOKUP $lookupTable name")
     val expectedResults: Array[Row] = Array(
-      Row(1000, "Jake", "Engineer", "England", 100000, 1000, "IT", "Engineer"),
-      Row(1001, "Hello", "Artist", "USA", 70000, null, null, null),
-      Row(1002, "John", "Doctor", "Canada", 120000, 1002, "DATA", "Scientist"),
-      Row(1003, "David", "Doctor", null, 120000, 1003, "HR", "Doctor"),
-      Row(1004, "David", null, "Canada", 0, 1003, "HR", "Doctor"),
-      Row(1005, "Jane", "Scientist", "Canada", 90000, 1005, "DATA", "Engineer"))
+      Row(1000, "Jake", "England", 100000, 1000, "IT", "Engineer"),
+      Row(1001, "Hello", "USA", 70000, null, null, null),
+      Row(1002, "John", "Canada", 120000, 1002, "DATA", "Scientist"),
+      Row(1003, "David", null, 120000, 1003, "HR", "Doctor"),
+      Row(1004, "David", "Canada", 0, 1003, "HR", "Doctor"),
+      Row(1005, "Jane", "Canada", 90000, 1005, "DATA", "Engineer"))
     assertSameRows(expectedResults, frame)
   }

@@ -5,6 +5,8 @@
 
 package org.opensearch.sql.ppl;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.spark.sql.catalyst.TableIdentifier;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
@@ -79,6 +81,7 @@
 import org.opensearch.sql.ppl.utils.FieldSummaryTransformer;
 import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator;
 import org.opensearch.sql.ppl.utils.ParseTransformer;
+import org.opensearch.sql.ppl.utils.RelationUtils;
 import org.opensearch.sql.ppl.utils.SortUtils;
 import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils;
 import org.opensearch.sql.ppl.utils.WindowSpecTransformer;
@@ -88,9 +91,11 @@
 import scala.collection.Seq;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
@@ -122,6 +127,7 @@
  * Utility class to traverse PPL logical plan and translate it into catalyst logical plan
  */
 public class CatalystQueryPlanVisitor extends AbstractNodeVisitor<LogicalPlan, CatalystPlanContext> {
+    private static final Logger LOG = LogManager.getLogger(CatalystQueryPlanVisitor.class);
 
     private final CatalystExpressionVisitor expressionAnalyzer;
 
@@ -193,16 +199,34 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
     public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
         visitFirstChild(node, context);
         return context.apply( searchSide -> {
-            context.retainAllNamedParseExpressions(p -> p);
-            context.retainAllPlans(p -> p);
             LogicalPlan target;
             LogicalPlan lookupTable = node.getLookupRelation().accept(this, context);
             Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context);
-            // If no output field is specified, all fields except mapping fields from lookup table are applied to the output.
             if (node.allFieldsShouldAppliedToOutputList()) {
-                target = join(searchSide, lookupTable, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());
+                context.retainAllNamedParseExpressions(p -> p);
+                context.retainAllPlans(p -> p);
+                // When no output field is specified, all fields of the lookup table except the mapping fields are applied to the output.
+                // If any fields on the source side duplicate fields of the lookup table, they are
+                // replaced by the lookup table fields in the output.
+                // For example, the lookup table contains fields [id, col1, col3] and the source side fields are [id, col1, col2].
+                // For the query "index = sourceTable | fields id, col1, col2 | LOOKUP lookupTable id",
+                // col1 is a duplicated field and id is the mapping key (also duplicated).
+                // The query outputs 4 fields: [id, col1, col2, col3]. Among them, `col2` is the original field from the source,
+                // and the matched values of `col1` from the lookup table replace the values of `col1` from the source.
+                Set<Field> intersection = new HashSet<>(RelationUtils.getFieldsFromCatalogTable(context.getSparkSession(), lookupTable));
+                LOG.debug("The fields list in lookup table are {}", intersection);
+                List<Field> sourceOutput = RelationUtils.getFieldsFromCatalogTable(context.getSparkSession(), searchSide);
+                LOG.debug("The fields list in source output are {}", sourceOutput);
+                Set<Field> mappingFieldsOfLookup = node.getLookupMappingMap().keySet();
+                // The lookup mapping keys are not dropped here; they are checked later.
+                intersection.removeAll(mappingFieldsOfLookup);
+                intersection.retainAll(sourceOutput);
+                List<Expression> duplicated = buildProjectListFromFields(new ArrayList<>(intersection), expressionAnalyzer, context)
+                    .stream().map(e -> (Expression) e).collect(Collectors.toList());
+                LogicalPlan searchSideWithDropped = DataFrameDropColumns$.MODULE$.apply(seq(duplicated), searchSide);
+                target = join(searchSideWithDropped, lookupTable, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());
             } else {
-                // If the output fields are specified, build a project list for lookup table.
+                // When output fields are specified, build a project list for the lookup table.
                 // The mapping fields of lookup table should be added in this project list, otherwise join will fail.
                 // So the mapping fields of lookup table should be dropped after join.
                 List<NamedExpression> lookupTableProjectList = buildLookupRelationProjectList(node, expressionAnalyzer, context);
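As a review aid (not part of the commit): the duplicate-detection above is plain set arithmetic over field names. Below is a minimal, self-contained sketch with hypothetical names (columnsToDrop, plain strings standing in for the PPL Field type):

import java.util.HashSet;
import java.util.List;
import java.util.Set;

class LookupDedupSketch {
    // Fields that exist on both sides and are not mapping keys are dropped
    // from the search side before the left join, so the lookup table wins.
    static Set<String> columnsToDrop(List<String> sourceFields,
                                     List<String> lookupFields,
                                     Set<String> mappingKeys) {
        Set<String> duplicated = new HashSet<>(lookupFields);
        duplicated.removeAll(mappingKeys);   // mapping keys are handled separately, later
        duplicated.retainAll(sourceFields);  // keep only true duplicates
        return duplicated;
    }

    public static void main(String[] args) {
        // The example from the comment: source [id, col1, col2], lookup [id, col1, col3].
        System.out.println(columnsToDrop(
                List.of("id", "col1", "col2"),
                List.of("id", "col1", "col3"),
                Set.of("id"))); // [col1]
    }
}

Dropping these columns from the search side before the left join is what makes the lookup table's values win, which is why the expected rows in the FlintSparkPPLLookupITSuite hunks above lost their source-side duplicate column.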
@@ -95,7 +95,7 @@ static List<NamedExpression> buildOutputProjectList(
             // If not, resolve the outputCol expression without an alias (<fieldName>) to avoid an unresolved-attribute failure.
             Expression inputCol = expressionAnalyzer.visitField(buildFieldWithLookupSubqueryAlias(node, inputField), context);
             Expression outputCol;
-            if (RelationUtils.columnExistsInCatalogTable(context.getSparkSession(), outputField, searchSide)) {
+            if (RelationUtils.columnExistsInCatalogTable(context.getSparkSession(), searchSide, outputField)) {
                 outputCol = expressionAnalyzer.visitField(buildFieldWithSourceSubqueryAlias(node, outputField), context);
             } else {
                 outputCol = expressionAnalyzer.visitField(outputField, context);
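For illustration only (hypothetical names, not the commit's code): the branch above chooses between an alias-qualified and a bare column reference, and the argument-order fix ensures the existence check inspects the search-side plan rather than passing the field where the plan belongs. A sketch of the decision under those assumptions:

import java.util.List;

class OutputColumnSketch {
    // Qualify the output column with the source-side subquery alias only when
    // the catalog says the column already exists on the search side; otherwise
    // use the bare field name so Catalyst can still resolve the attribute.
    static String resolveOutputCol(List<String> sourceColumns, String sourceAlias, String outputField) {
        return sourceColumns.contains(outputField)
                ? sourceAlias + "." + outputField
                : outputField;
    }

    public static void main(String[] args) {
        List<String> sourceCols = List.of("id", "name", "country");
        System.out.println(resolveOutputCol(sourceCols, "s", "country"));    // s.country
        System.out.println(resolveOutputCol(sourceCols, "s", "occupation")); // occupation
    }
}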
@@ -5,6 +5,9 @@
 
 package org.opensearch.sql.ppl.utils;
 
+import com.google.common.collect.ImmutableList;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.TableIdentifier;
 import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException;
@@ -20,10 +23,10 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
-import java.util.logging.Logger;
+import java.util.stream.Collectors;
 
 public interface RelationUtils {
-    Logger LOG = Logger.getLogger(RelationUtils.class.getName());
+    Logger LOG = LogManager.getLogger(RelationUtils.class);
 
     /**
      * attempt resolving if the field is relating to the given relation
@@ -76,22 +79,26 @@ static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) {
         return identifier;
     }
 
-    static boolean columnExistsInCatalogTable(SparkSession spark, Field field, LogicalPlan plan) {
+    static boolean columnExistsInCatalogTable(SparkSession spark, LogicalPlan plan, Field field) {
+        return getFieldsFromCatalogTable(spark, plan).stream().anyMatch(f -> f.getField().equals(field.getField()));
+    }
+
+    static List<Field> getFieldsFromCatalogTable(SparkSession spark, LogicalPlan plan) {
         UnresolvedRelation relation = PPLSparkUtils.findLogicalRelations(plan).head();
         QualifiedName tableQualifiedName = QualifiedName.of(Arrays.asList(relation.tableName().split("\\.")));
-        TableIdentifier sourceTableIdentifier = getTableIdentifier(tableQualifiedName);
-        boolean sourceTableExists = spark.sessionState().catalog().tableExists(sourceTableIdentifier);
-        if (sourceTableExists) {
+        TableIdentifier tableIdentifier = getTableIdentifier(tableQualifiedName);
+        boolean tableExists = spark.sessionState().catalog().tableExists(tableIdentifier);
+        if (tableExists) {
             try {
                 CatalogTable table = spark.sessionState().catalog().getTableMetadata(getTableIdentifier(tableQualifiedName));
-                return Arrays.stream(table.dataSchema().fields()).anyMatch(f -> f.name().equalsIgnoreCase(field.getField().toString()));
+                return Arrays.stream(table.dataSchema().fields()).map(f -> new Field(QualifiedName.of(f.name()))).collect(Collectors.toList());
             } catch (NoSuchDatabaseException | NoSuchTableException e) {
-                LOG.warning("Source table or database " + sourceTableIdentifier + " not found");
-                return false;
+                LOG.info("Table or database {} not found", tableIdentifier);
+                return ImmutableList.of();
             }
         } else {
-            LOG.warning("Source table " + sourceTableIdentifier + " not found");
-            return false;
+            LOG.info("Table {} not found", tableIdentifier);
+            return ImmutableList.of();
         }
     }
 }
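Two observable consequences of this rewrite, shown in a self-contained sketch (hypothetical names, a plain map standing in for the Spark catalog): the empty list replaces false as the "table not found" signal, and name matching became case-sensitive, since the new code compares with equals where the removed line used equalsIgnoreCase.

import java.util.List;
import java.util.Map;

class CatalogFieldsSketch {
    // A map stands in for spark.sessionState().catalog().
    static final Map<String, List<String>> CATALOG =
            Map.of("workers", List.of("id", "name", "country"));

    // Empty list is the sentinel for a missing table, mirroring ImmutableList.of() above.
    static List<String> getFieldsFromCatalogTable(String table) {
        return CATALOG.getOrDefault(table, List.of());
    }

    static boolean columnExists(String table, String field) {
        // Exact match, like f.getField().equals(field.getField()) above.
        return getFieldsFromCatalogTable(table).stream().anyMatch(f -> f.equals(field));
    }

    public static void main(String[] args) {
        System.out.println(columnExists("workers", "country"));  // true
        System.out.println(columnExists("workers", "COUNTRY"));  // false: case-sensitive now
        System.out.println(columnExists("missing", "country"));  // false, no exception
    }
}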
