diff --git a/docs/source/api/spatial-functions.rst b/docs/source/api/spatial-functions.rst index 73be0ce27..2158f397c 100644 --- a/docs/source/api/spatial-functions.rst +++ b/docs/source/api/spatial-functions.rst @@ -1721,7 +1721,7 @@ st_x .. function:: st_x(col) - Returns the x coordinate of the input geometry. + Returns the x coordinate of the centroid point of the input geometry. :param col: Geometry :type col: Column @@ -1880,7 +1880,7 @@ st_y **** .. function:: st_y(col) - Returns the y coordinate of the input geometry. + Returns the y coordinate of the centroid point of the input geometry. :param col: Geometry :type col: Column @@ -2036,6 +2036,59 @@ st_ymin +-----------------+ +st_z +**** +.. function:: st_z(col) + + Returns the z coordinate of an arbitrary point of the input geometry `geom`. + + :param col: Point Geometry + :type col: Column + :rtype: Column: DoubleType + + :example: + +.. tabs:: + .. code-tab:: py + + df = spark.createDataFrame([{'wkt': 'POINT (30 10 20)'}]) + df.select(st_z('wkt')).show() + +-----------------+ + |st_z(wkt) | + +-----------------+ + | 20.0| + +-----------------+ + + .. code-tab:: scala + + val df = List(("POINT (30 10 20)")).toDF("wkt") + df.select(st_z(col("wkt"))).show() + +-----------------+ + |st_z(wkt) | + +-----------------+ + | 20.0| + +-----------------+ + + .. code-tab:: sql + + SELECT st_z("POINT (30 10 20)") + +-----------------+ + |st_z(wkt) | + +-----------------+ + | 20.0| + +-----------------+ + + .. code-tab:: r R + + df <- createDataFrame(data.frame(wkt = "POINT (30 10 20)")) + showDF(select(df, st_z(column("wkt"))), truncate=F) + +-----------------+ + |st_z(wkt) | + +-----------------+ + | 20.0| + +-----------------+ + + st_zmax ******* diff --git a/python/mosaic/api/functions.py b/python/mosaic/api/functions.py index 9819caa5f..e2165195b 100644 --- a/python/mosaic/api/functions.py +++ b/python/mosaic/api/functions.py @@ -50,6 +50,7 @@ "st_zmax", "st_x", "st_y", + "st_z", "flatten_polygons", "grid_boundaryaswkb", "grid_boundary", @@ -753,7 +754,7 @@ def st_updatesrid( def st_x(geom: ColumnOrName) -> Column: """ - Returns the x coordinate of the input geometry `geom`. + Returns the x coordinate of the centroid point of the input geometry `geom`. Parameters ---------- @@ -769,7 +770,7 @@ def st_x(geom: ColumnOrName) -> Column: def st_y(geom: ColumnOrName) -> Column: """ - Returns the y coordinate of the input geometry `geom`. + Returns the y coordinate of the centroid point of the input geometry `geom`. Parameters ---------- @@ -783,6 +784,22 @@ def st_y(geom: ColumnOrName) -> Column: return config.mosaic_context.invoke_function("st_y", pyspark_to_java_column(geom)) +def st_z(geom: ColumnOrName) -> Column: + """ + Returns the z coordinate of an arbitrary point of the input geometry `geom`. + + Parameters + ---------- + geom : Column + + Returns + ------- + Column (DoubleType) + + """ + return config.mosaic_context.invoke_function("st_z", pyspark_to_java_column(geom)) + + def st_geometrytype(geom: ColumnOrName) -> Column: """ Returns the type of the input geometry `geom` (“POINT”, “LINESTRING”, “POLYGON” etc.). diff --git a/python/test/test_vector_functions.py b/python/test/test_vector_functions.py index 3a127327a..8dbb191a9 100644 --- a/python/test/test_vector_functions.py +++ b/python/test/test_vector_functions.py @@ -1,6 +1,6 @@ import random -from pyspark.sql.functions import abs, col, first, lit, sqrt +from pyspark.sql.functions import abs, col, concat, first, lit, sqrt from .context import api from .utils import MosaicTestCase @@ -27,6 +27,25 @@ def test_st_point(self): ) self.assertListEqual([rw.points for rw in result], expected) + def test_st_z(self): + expected = [ + 0, + 1, + ] + result = ( + self.spark.range(2) + .select(col("id").cast("double")) + .withColumn( + "points", + api.st_geomfromwkt( + concat(lit("POINT (9 9 "), "id", lit(")")) + ), + ) + .withColumn("z", api.st_z("points")) + .collect() + ) + self.assertListEqual([rw.z for rw in result], expected) + def test_st_bindings_happy_flow(self): # Checks that the python bindings do not throw exceptions # Not testing the logic, since that is tested in Scala diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala index 0093aa7c5..3c639446e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala @@ -41,6 +41,8 @@ trait MosaicGeometry extends GeometryWriter with Serializable { def getCentroid: MosaicPoint + def getAnyPoint: MosaicPoint + def getDimension: Int def isEmpty: Boolean diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala index d509bc9ec..e860d0b67 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala @@ -55,6 +55,16 @@ abstract class MosaicGeometryJTS(geom: Geometry) extends MosaicGeometry { MosaicPointJTS(centroid) } + override def getAnyPoint: MosaicPointJTS = { + // while this doesn't return the centroid but an arbitrary point via getCoordinate in JTS, + // inlike getCentroid this supports a Z coordinate. + + val coord = geom.getCoordinate + val gf = new GeometryFactory() + val point = gf.createPoint(coord) + MosaicPointJTS(point) + } + override def isEmpty: Boolean = geom.isEmpty override def boundary: MosaicGeometryJTS = MosaicGeometryJTS(geom.getBoundary) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_X.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_X.scala index f95af0921..bfb87ad71 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_X.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_X.scala @@ -44,7 +44,7 @@ object ST_X extends WithExpressionInfo { override def name: String = "st_x" override def usage: String = - "_FUNC_(expr1) - Returns x coordinate of a point or x coordinate of the centroid if the geometry isnt a point." + "_FUNC_(expr1) - Returns x coordinate of a point or x coordinate of the centroid if the geometry isn't a point." override def example: String = """ diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Y.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Y.scala index ff626ec0b..6e62fa7bf 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Y.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Y.scala @@ -44,7 +44,7 @@ object ST_Y extends WithExpressionInfo { override def name: String = "st_y" override def usage: String = - "_FUNC_(expr1) - Returns y coordinate of a point or y coordinate of the centroid if the geometry isnt a point." + "_FUNC_(expr1) - Returns y coordinate of a point or y coordinate of the centroid if the geometry isn't a point." override def example: String = """ diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Z.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Z.scala new file mode 100644 index 000000000..ea4e0d67a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Z.scala @@ -0,0 +1,57 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.core.geometry.MosaicGeometry +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.geometry.base.UnaryVectorExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.{DataType, DoubleType} + +/** + * SQL expression that returns Z coordinate of the input point. Input must be a point. + * + * @param inputGeom + * Expression containing the geometry. + * @param expressionConfig + * Mosaic execution context, e.g. geometryAPI, indexSystem, etc. Additional + * arguments for the expression (expressionConfigs). + */ +case class ST_Z( + inputGeom: Expression, + expressionConfig: MosaicExpressionConfig +) extends UnaryVectorExpression[ST_Z](inputGeom, returnsGeometry = false, expressionConfig) { + + override def dataType: DataType = DoubleType + + override def geometryTransform(geometry: MosaicGeometry): Any = geometry.getAnyPoint.getZ + + override def geometryCodeGen(geometryRef: String, ctx: CodegenContext): (String, String) = { + val resultRef = ctx.freshName("result") + val code = s"""double $resultRef = $geometryRef.getAnyPoint().getZ();""" + (code, resultRef) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object ST_Z extends WithExpressionInfo { + + override def name: String = "st_z" + + override def usage: String = + "_FUNC_(expr1) - Returns z coordinate of a point or z coordinate of an arbitrary point in geometry if it isn't a point." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(a); + | 12.3 + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[ST_Z](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 596d1e2a7..905d6962e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -181,6 +181,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[ST_Within](expressionConfig) mosaicRegistry.registerExpression[ST_X](expressionConfig) mosaicRegistry.registerExpression[ST_Y](expressionConfig) + mosaicRegistry.registerExpression[ST_Z](expressionConfig) mosaicRegistry.registerExpression[ST_Haversine](expressionConfig) // noinspection ScalaDeprecation @@ -600,6 +601,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(ST_Translate(geom1.expr, xd.expr, yd.expr, expressionConfig)) def st_x(geom: Column): Column = ColumnAdapter(ST_X(geom.expr, expressionConfig)) def st_y(geom: Column): Column = ColumnAdapter(ST_Y(geom.expr, expressionConfig)) + def st_z(geom: Column): Column = ColumnAdapter(ST_Z(geom.expr, expressionConfig)) def st_xmax(geom: Column): Column = ColumnAdapter(ST_MinMaxXYZ(geom.expr, expressionConfig, "X", "MAX")) def st_xmin(geom: Column): Column = ColumnAdapter(ST_MinMaxXYZ(geom.expr, expressionConfig, "X", "MIN")) def st_ymax(geom: Column): Column = ColumnAdapter(ST_MinMaxXYZ(geom.expr, expressionConfig, "Y", "MAX")) diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZBehaviors.scala new file mode 100644 index 000000000..3ce7ad57d --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZBehaviors.scala @@ -0,0 +1,87 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.functions.MosaicContext +import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} +import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.scalatest.matchers.must.Matchers.noException +import org.scalatest.matchers.should.Matchers.{an, be, convertToAnyShouldWrapper} + +trait ST_ZBehaviors extends MosaicSpatialQueryTest { + + def stzBehavior(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + import mc.functions._ + val sc = spark + import sc.implicits._ + mc.register(spark) + + val rows = List( + ("POINT (2 3 5)", 5), + ("POINT (7 11 13)", 13), + ("POINT (17 19 23)", 23), + ("POINT (29 31 37)", 37) + ) + + val result = rows + .toDF("wkt", "expected") + .withColumn("result", st_z($"wkt")) + .where($"expected" === $"result") + + result.count shouldBe 4 + } + + def stzCodegen(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + val sc = spark + import mc.functions._ + import sc.implicits._ + mc.register(spark) + + val rows = List( + ("POINT (2 3 5)", 5), + ("POINT (7 11 13)", 13), + ("POINT (17 19 23)", 23), + ("POINT (29 31 37)", 37) + ) + + val points = rows.toDF("wkt", "expected") + + val result = points + .withColumn("result", st_z($"wkt")) + .where($"expected" === $"result") + + val queryExecution = result.queryExecution + val plan = queryExecution.executedPlan + + val wholeStageCodegenExec = plan.find(_.isInstanceOf[WholeStageCodegenExec]) + + wholeStageCodegenExec.isDefined shouldBe true + + val codeGenStage = wholeStageCodegenExec.get.asInstanceOf[WholeStageCodegenExec] + val (_, code) = codeGenStage.doCodeGen() + + noException should be thrownBy CodeGenerator.compile(code) + + val stZ = ST_Z(lit(1).expr, mc.expressionConfig) + val ctx = new CodegenContext + an[Error] should be thrownBy stZ.genCode(ctx) + } + + def auxiliaryMethods(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + mc.register(spark) + + val stZ = ST_Z(lit("POINT (2 3 4)").expr, mc.expressionConfig) + + stZ.child shouldEqual lit("POINT (2 3 4)").expr + stZ.dataType shouldEqual DoubleType + noException should be thrownBy stZ.makeCopy(Array(stZ.child)) + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZTest.scala new file mode 100644 index 000000000..4e2e99b26 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZTest.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest +import org.apache.spark.sql.test.SharedSparkSession + +class ST_ZTest extends MosaicSpatialQueryTest with SharedSparkSession with ST_ZBehaviors { + + testAllGeometriesNoCodegen("Testing stZ NO_CODEGEN") { stzBehavior } + testAllGeometriesCodegen("Testing stZ CODEGEN") { stzBehavior } + testAllGeometriesCodegen("Testing stZ CODEGEN compilation") { stzCodegen } + testAllGeometriesNoCodegen("Testing stZ auxiliary methods") { auxiliaryMethods } + +}