diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala
index 171508c86..085bdf7fd 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLLookupITSuite.scala
@@ -298,12 +298,12 @@ class FlintSparkPPLLookupITSuite
   test("test LOOKUP lookupTable uid AS id, name") {
     val frame = sql(s"source = $sourceTable| LOOKUP $lookupTable uID AS id, name")
     val expectedResults: Array[Row] = Array(
-      Row(1000, "Jake", "Engineer", "England", 100000, "IT", "Engineer"),
-      Row(1001, "Hello", "Artist", "USA", 70000, null, null),
-      Row(1002, "John", "Doctor", "Canada", 120000, "DATA", "Scientist"),
-      Row(1003, "David", "Doctor", null, 120000, "HR", "Doctor"),
-      Row(1004, "David", null, "Canada", 0, null, null),
-      Row(1005, "Jane", "Scientist", "Canada", 90000, "DATA", "Engineer"))
+      Row(1000, "Jake", "England", 100000, "IT", "Engineer"),
+      Row(1001, "Hello", "USA", 70000, null, null),
+      Row(1002, "John", "Canada", 120000, "DATA", "Scientist"),
+      Row(1003, "David", null, 120000, "HR", "Doctor"),
+      Row(1004, "David", "Canada", 0, null, null),
+      Row(1005, "Jane", "Canada", 90000, "DATA", "Engineer"))
 
     assertSameRows(expectedResults, frame)
   }
@@ -420,12 +420,12 @@
     val frame = sql(s"source = $sourceTable | LOOKUP $lookupTable name")
 
     val expectedResults: Array[Row] = Array(
-      Row(1000, "Jake", "Engineer", "England", 100000, 1000, "IT", "Engineer"),
-      Row(1001, "Hello", "Artist", "USA", 70000, null, null, null),
-      Row(1002, "John", "Doctor", "Canada", 120000, 1002, "DATA", "Scientist"),
-      Row(1003, "David", "Doctor", null, 120000, 1003, "HR", "Doctor"),
-      Row(1004, "David", null, "Canada", 0, 1003, "HR", "Doctor"),
-      Row(1005, "Jane", "Scientist", "Canada", 90000, 1005, "DATA", "Engineer"))
+      Row(1000, "Jake", "England", 100000, 1000, "IT", "Engineer"),
+      Row(1001, "Hello", "USA", 70000, null, null, null),
+      Row(1002, "John", "Canada", 120000, 1002, "DATA", "Scientist"),
+      Row(1003, "David", null, 120000, 1003, "HR", "Doctor"),
+      Row(1004, "David", "Canada", 0, 1003, "HR", "Doctor"),
+      Row(1005, "Jane", "Canada", 90000, 1005, "DATA", "Engineer"))
 
     assertSameRows(expectedResults, frame)
   }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index b92575394..ec7e85d84 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -5,6 +5,8 @@
 
 package org.opensearch.sql.ppl;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.spark.sql.catalyst.TableIdentifier;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
@@ -79,6 +81,7 @@
 import org.opensearch.sql.ppl.utils.FieldSummaryTransformer;
 import org.opensearch.sql.ppl.utils.GeoIpCatalystLogicalPlanTranslator;
 import org.opensearch.sql.ppl.utils.ParseTransformer;
+import org.opensearch.sql.ppl.utils.RelationUtils;
 import org.opensearch.sql.ppl.utils.SortUtils;
 import org.opensearch.sql.ppl.utils.TrendlineCatalystUtils;
 import org.opensearch.sql.ppl.utils.WindowSpecTransformer;
@@ -88,9 +91,11 @@
 import scala.collection.Seq;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyList;
@@ -122,6 +127,7 @@
  * Utility class to traverse PPL logical plan and translate it into catalyst logical plan
  */
 public class CatalystQueryPlanVisitor extends AbstractNodeVisitor<LogicalPlan, CatalystPlanContext> {
+    private static final Logger LOG = LogManager.getLogger(CatalystQueryPlanVisitor.class);
 
     private final CatalystExpressionVisitor expressionAnalyzer;
 
@@ -193,16 +199,34 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
     public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
         visitFirstChild(node, context);
         return context.apply( searchSide -> {
+            context.retainAllNamedParseExpressions(p -> p);
+            context.retainAllPlans(p -> p);
             LogicalPlan target;
             LogicalPlan lookupTable = node.getLookupRelation().accept(this, context);
             Expression lookupCondition = buildLookupMappingCondition(node, expressionAnalyzer, context);
-            // If no output field is specified, all fields except mapping fields from lookup table are applied to the output.
             if (node.allFieldsShouldAppliedToOutputList()) {
-                context.retainAllNamedParseExpressions(p -> p);
-                context.retainAllPlans(p -> p);
-                target = join(searchSide, lookupTable, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());
+                // When no output field is specified, all fields except mapping fields from the lookup table are applied to the output.
+                // If some output fields from the source side duplicate fields of the lookup table, these fields will
+                // be replaced by the fields from the lookup table in the output.
+                // For example, the lookup table contains fields [id, col1, col3] and the source side fields are [id, col1, col2].
+                // For the query "index = sourceTable | fields id, col1, col2 | LOOKUP lookupTable id",
+                // col1 is a duplicated field and id is the mapping key (also duplicated).
+                // The query outputs 4 fields: [id, col1, col2, col3]. Among them, `col2` is the original field from the source,
+                // and the matched values of `col1` from the lookup table replace the values of `col1` from the source.
+                Set<Field> intersection = new HashSet<>(RelationUtils.getFieldsFromCatalogTable(context.getSparkSession(), lookupTable));
+                LOG.debug("The fields list in lookup table are {}", intersection);
+                List<Field> sourceOutput = RelationUtils.getFieldsFromCatalogTable(context.getSparkSession(), searchSide);
+                LOG.debug("The fields list in source output are {}", sourceOutput);
+                Set<Field> mappingFieldsOfLookup = node.getLookupMappingMap().keySet();
+                // Lookup mapping keys are not dropped here; they are checked later.
+                intersection.removeAll(mappingFieldsOfLookup);
+                intersection.retainAll(sourceOutput);
+                List<Expression> duplicated = buildProjectListFromFields(new ArrayList<>(intersection), expressionAnalyzer, context)
+                    .stream().map(e -> (Expression) e).collect(Collectors.toList());
+                LogicalPlan searchSideWithDropped = DataFrameDropColumns$.MODULE$.apply(seq(duplicated), searchSide);
+                target = join(searchSideWithDropped, lookupTable, Join.JoinType.LEFT, Optional.of(lookupCondition), new Join.JoinHint());
             } else {
-                // If the output fields are specified, build a project list for lookup table.
+                // When output fields are specified, build a project list for lookup table.
                 // The mapping fields of lookup table should be added in this project list, otherwise join will fail.
                 // So the mapping fields of lookup table should be dropped after join.
                 List<NamedExpression> lookupTableProjectList = buildLookupRelationProjectList(node, expressionAnalyzer, context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java
index d2de00545..a9c5e9f5e 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/LookupTransformer.java
@@ -95,7 +95,7 @@ static List buildOutputProjectList(
             // If not, resolve the outputCol expression without alias: to avoid failure of unable to resolved attribute.
             Expression inputCol = expressionAnalyzer.visitField(buildFieldWithLookupSubqueryAlias(node, inputField), context);
             Expression outputCol;
-            if (RelationUtils.columnExistsInCatalogTable(context.getSparkSession(), outputField, searchSide)) {
+            if (RelationUtils.columnExistsInCatalogTable(context.getSparkSession(), searchSide, outputField)) {
                 outputCol = expressionAnalyzer.visitField(buildFieldWithSourceSubqueryAlias(node, outputField), context);
             } else {
                 outputCol = expressionAnalyzer.visitField(outputField, context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java
index 8bcb9d393..fa9faf8af 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java
@@ -5,6 +5,9 @@
 
 package org.opensearch.sql.ppl.utils;
 
+import com.google.common.collect.ImmutableList;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.TableIdentifier;
 import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException;
@@ -20,10 +23,10 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
-import java.util.logging.Logger;
+import java.util.stream.Collectors;
 
 public interface RelationUtils {
-    Logger LOG = Logger.getLogger(RelationUtils.class.getName());
+    Logger LOG = LogManager.getLogger(RelationUtils.class);
 
     /**
      * attempt resolving if the field is relating to the given relation
@@ -76,22 +79,26 @@ static TableIdentifier getTableIdentifier(QualifiedName qualifiedName) {
         return identifier;
     }
 
-    static boolean columnExistsInCatalogTable(SparkSession spark, Field field, LogicalPlan plan) {
+    static boolean columnExistsInCatalogTable(SparkSession spark, LogicalPlan plan, Field field) {
+        return getFieldsFromCatalogTable(spark, plan).stream().anyMatch(f -> f.getField().equals(field.getField()));
+    }
+
+    static List<Field> getFieldsFromCatalogTable(SparkSession spark, LogicalPlan plan) {
         UnresolvedRelation relation = PPLSparkUtils.findLogicalRelations(plan).head();
         QualifiedName tableQualifiedName = QualifiedName.of(Arrays.asList(relation.tableName().split("\\.")));
-        TableIdentifier sourceTableIdentifier = getTableIdentifier(tableQualifiedName);
-        boolean sourceTableExists = spark.sessionState().catalog().tableExists(sourceTableIdentifier);
-        if (sourceTableExists) {
+        TableIdentifier tableIdentifier = getTableIdentifier(tableQualifiedName);
+        boolean tableExists = spark.sessionState().catalog().tableExists(tableIdentifier);
+        if (tableExists) {
             try {
                 CatalogTable table = spark.sessionState().catalog().getTableMetadata(getTableIdentifier(tableQualifiedName));
-                return Arrays.stream(table.dataSchema().fields()).anyMatch(f -> f.name().equalsIgnoreCase(field.getField().toString()));
+                return Arrays.stream(table.dataSchema().fields()).map(f -> new Field(QualifiedName.of(f.name()))).collect(Collectors.toList());
             } catch (NoSuchDatabaseException | NoSuchTableException e) {
-                LOG.warning("Source table or database " + sourceTableIdentifier + " not found");
-                return false;
+                LOG.info("Table or database {} not found", tableIdentifier);
+                return ImmutableList.of();
             }
         } else {
-            LOG.warning("Source table " + sourceTableIdentifier + " not found");
-            return false;
+            LOG.info("Table {} not found", tableIdentifier);
+            return ImmutableList.of();
         }
     }
 }
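
Editorial note (not part of the patch): the behaviour described in the new visitLookup comment can be reproduced with plain Spark DataFrames. The sketch below is illustrative only; the table contents, the LookupDropDuplicatesSketch object name, and the use of a using-column join are assumptions that approximate the generated plan (drop the duplicated non-key columns from the search side, then LEFT OUTER join on the mapping key), not the actual DataFrameDropColumns-based plan built by the visitor.

import org.apache.spark.sql.SparkSession

// Illustrative only: mirrors "drop duplicated source columns, then LEFT JOIN on the
// mapping key" for LOOKUP with no explicit output fields. Data and names are hypothetical.
object LookupDropDuplicatesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("lookup-sketch").getOrCreate()
    import spark.implicits._

    // Source side has [id, col1, col2]; lookup side has [id, col1, col3], as in the comment above.
    val source = Seq((1000, "a", "x"), (1001, "b", "y")).toDF("id", "col1", "col2")
    val lookup = Seq((1000, "A", "Z")).toDF("id", "col1", "col3")

    // col1 exists on both sides and is not a mapping key, so it is dropped from the
    // source side; the joined result keeps the lookup table's col1 (null when unmatched).
    val result = source
      .drop("col1")
      .join(lookup, Seq("id"), "left_outer")

    result.show() // output columns: id, col2, col1, col3
    spark.stop()
  }
}

With this example data the output has the four columns [id, col2, col1, col3], matching the 4-field output the comment describes, with col1 sourced from the lookup table.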
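Also illustrative: getFieldsFromCatalogTable resolves the field list through the session catalog and falls back to an empty list when the table or database is missing. A rough Scala equivalent of that catalog read is sketched below; the helper name catalogFieldNames is hypothetical, and error handling is reduced to the empty-list fallback.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier

// Sketch of reading a table's column names from the session catalog, roughly what
// getFieldsFromCatalogTable does (minus the PPL Field wrapper and exception handling).
object CatalogFieldsSketch {
  def catalogFieldNames(spark: SparkSession, table: String, db: Option[String] = None): Seq[String] = {
    val id = TableIdentifier(table, db)
    if (spark.sessionState.catalog.tableExists(id))
      spark.sessionState.catalog.getTableMetadata(id).dataSchema.fields.map(_.name).toSeq
    else
      Seq.empty // the patch logs at INFO and returns an empty list in this case
  }
}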