Merge branch 'databrickslabs:main' into main
a0x8o authored Apr 26, 2023
2 parents 3c60227 + 5acbc2e commit a9fcc5e
Showing 15 changed files with 257 additions and 59 deletions.
21 changes: 15 additions & 6 deletions docs/source/api/raster-format-readers.rst
@@ -56,7 +56,7 @@ The output of the reader is a DataFrame with the following columns:
.. code-tab:: py

>>> df = spark.read.format("gdal")\
.option("driverName", "TIF")\
.option("driverName", "GTiff")\
.load("dbfs:/path/to/raster.tif")
>>> df.show()
+--------------------+-----+-----+---------+--------------------+--------------------+----+--------------------+
@@ -68,7 +68,7 @@ The output of the reader is a DataFrame with the following columns:
.. code-tab:: scala

>>> val df = spark.read.format("gdal")
.option("driverName", "TIF")
.option("driverName", "GTiff")
.load("dbfs:/path/to/raster.tif")
>>> df.show()
+--------------------+-----+-----+---------+--------------------+--------------------+----+--------------------+
@@ -77,6 +77,11 @@ The output of the reader is a DataFrame with the following columns:
|dbfs:/path/to/ra...| 100| 100| 1|{AREA_OR_POINT=Po...| null| 4326|+proj=longlat +da...|
+--------------------+-----+-----+---------+--------------------+--------------------+----+--------------------+

.. warning::
Issue 350: https://github.com/databrickslabs/mosaic/issues/350
The raster reader's 'driverName' option must match one of the GDAL driver names listed above.
For example, to read a GeoTIFF file, use the following option:
.option("driverName", "GTiff") instead of .option("driverName", "tif").
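
To check which driver names are valid on a given cluster, a minimal sketch along these lines can be run in a notebook (it assumes the osgeo GDAL Python bindings are available, which the Mosaic GDAL setup normally provides):

    from osgeo import gdal

    gdal.AllRegister()  # make sure every driver is registered
    driver_names = sorted(gdal.GetDriver(i).ShortName for i in range(gdal.GetDriverCount()))
    print("GTiff" in driver_names)   # True -- the correct name for GeoTIFF rasters
    print("TIF" in driver_names)     # False -- which is why the wrong option value fails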


mos.read().format("raster_to_grid")
@@ -100,7 +105,6 @@ The reader supports the following options:
* subdatasetName - if the raster has subdatasets, select a specific subdataset by name (StringType)
* kRingInterpolate - if the raster pixels are larger than the grid cells, use k_ring interpolation with n = kRingInterpolate (IntegerType)


.. function:: mos.read().format("raster_to_grid").load(path)

Loads a GDAL raster file and returns the result as a DataFrame.
@@ -116,7 +120,7 @@ The reader supports the following options:
.. code-tab:: py

>>> df = mos.read().format("raster_to_grid")\
.option("fileExtension", "tif")\
.option("fileExtension", "*.tif")\
.option("resolution", "8")\
.option("combiner", "mean")\
.option("retile", "true")\
@@ -136,7 +140,7 @@ The reader supports the following options:
.. code-tab:: scala

>>> val df = MosaicContext.read.format("raster_to_grid")
.option("fileExtension", "tif")
.option("fileExtension", "*.tif")
.option("resolution", "8")
.option("combiner", "mean")
.option("retile", "true")
Expand All @@ -151,4 +155,9 @@ The reader supports the following options:
| 1| 2|0.1400000000000000|
| 1| 3|0.2464000000000000|
| 1| 4|0.2464000000000000|
+--------+--------+------------------+
+--------+--------+------------------+

.. warning::
Issue 350: https://github.com/databrickslabs/mosaic/issues/350
The option 'fileExtension' expects a wildcard mask. Please use the format '*.tif' or the equivalent for other formats.
If you use 'tif' without the wildcard, the reader won't pick up any files and the resulting table will be empty.
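
As a quick sanity check before launching the read, a hedged sketch like the one below confirms that the mask actually matches files (the path is illustrative; dbutils is available in Databricks notebooks):

    import fnmatch

    raster_dir = "dbfs:/path/to/rasters/"            # illustrative location
    file_names = [f.name for f in dbutils.fs.ls(raster_dir)]
    matched = fnmatch.filter(file_names, "*.tif")    # the same mask passed to 'fileExtension'
    print(f"{len(matched)} of {len(file_names)} files match the mask")
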
53 changes: 47 additions & 6 deletions docs/source/api/spatial-functions.rst
@@ -238,8 +238,8 @@ st_bufferloop

Fig 1. ST_BufferLoop(geom, 0.02, 0.04)

st_centroid2D
*************
st_centroid2D [Deprecated]
**************************

.. function:: st_centroid2D(col)

@@ -291,17 +291,58 @@ st_centroid2D
|{25.454545454545453, 26.96969696969697}|
+---------------------------------------+

st_centroid3D
st_centroid
*************

.. function:: st_centroid3D(col)
.. function:: st_centroid(col)

Returns the x, y and z coordinates representing the centroid of the input geometry.
Returns the POINT geometry representing the centroid of the input geometry.

:param col: Geometry
:type col: Column
:rtype: Column: StructType[x: DoubleType, y: DoubleType, z: DoubleType]
:rtype: Column: Geometry

:example:

.. tabs::
.. code-tab:: py

>>> df = spark.createDataFrame([{'wkt': 'POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))'}])
>>> df.select(st_centroid('wkt')).show()
+---------------------------------------------+
|st_centroid(wkt) |
+---------------------------------------------+
|POINT (25.454545454545453 26.96969696969697) |
+---------------------------------------------+

.. code-tab:: scala

>>> val df = List(("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")).toDF("wkt")
>>> df.select(st_centroid(col("wkt"))).show()
+---------------------------------------------+
|st_centroid(wkt) |
+---------------------------------------------+
|POINT (25.454545454545453 26.96969696969697) |
+---------------------------------------------+

.. code-tab:: sql

>>> SELECT st_centroid("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")
+---------------------------------------------+
|st_centroid(wkt) |
+---------------------------------------------+
|POINT (25.454545454545453 26.96969696969697) |
+---------------------------------------------+

.. code-tab:: r R

>>> df <- createDataFrame(data.frame(wkt = "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))"))
>>> showDF(select(df, st_centroid(column("wkt"))), truncate=F)
+---------------------------------------------+
|st_centroid(wkt) |
+---------------------------------------------+
|POINT (25.454545454545453 26.96969696969697) |
+---------------------------------------------+
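
Since st_centroid now returns a POINT geometry rather than an {x, y} struct, the individual coordinates can be recovered with the coordinate accessors; a short Python sketch (assuming st_x and st_y from the Mosaic API):

    >>> df.select(st_x(st_centroid("wkt")).alias("x"), st_y(st_centroid("wkt")).alias("y")).show()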

st_convexhull
*************
15 changes: 15 additions & 0 deletions docs/source/usage/automatic-sql-registration.rst
@@ -73,3 +73,18 @@ To test the installation, create a new Python notebook and run the following com
spark.sql("""show functions""").where("startswith(function, 'st_')").display()
You should see all the supported functions registered by Mosaic appear in the output.
.. warning::
Issue 317: https://github.com/databrickslabs/mosaic/issues/317
The Mosaic jar needs to be installed via an init script, not through the cluster UI.
Automatic SQL registration has to happen at cluster start-up, when the Spark context is created.
Libraries installed through the cluster UI are made available too late, so automatic SQL registration
will not work, and there is no way to surface an error message in that case.
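
A minimal, hedged sketch of such an init script, written from a notebook; the DBFS locations and jar file name are illustrative and need to be adapted to your workspace:

    init_script = """#!/bin/bash
    cp /dbfs/FileStore/mosaic/jars/mosaic-assembly.jar /databricks/jars/
    """
    dbutils.fs.put("dbfs:/FileStore/mosaic/init/install-mosaic.sh", init_script, overwrite=True)
    # then reference dbfs:/FileStore/mosaic/init/install-mosaic.sh as a cluster-scoped init script
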
.. warning::
Issue 297: https://github.com/databrickslabs/mosaic/issues/297
Since Mosaic V0.3.6, Automatic SQL Registration can fail with the error message
"java.lang.Exception: spark.databricks.labs.mosaic.raster.api". This is due to a missing key in the Spark
configuration. The issue is fixed as of Mosaic V0.3.10. For releases between V0.3.6 and V0.3.10,
please add the following configuration to your cluster Spark config: (spark.databricks.labs.mosaic.raster.api, "GDAL"),
or alternatively set it in Python/Scala code: spark.conf.set("spark.databricks.labs.mosaic.raster.api", "GDAL")
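
For the affected releases the key can be supplied either through the cluster's Spark config or from code before Mosaic is used; a short sketch:

    # In the cluster UI (Advanced options > Spark config) add the line:
    #     spark.databricks.labs.mosaic.raster.api GDAL
    # or set it programmatically:
    spark.conf.set("spark.databricks.labs.mosaic.raster.api", "GDAL")
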
7 changes: 7 additions & 0 deletions python/mosaic/api/enable.py
@@ -54,6 +54,13 @@ def enable_mosaic(spark: SparkSession, dbutils=None) -> None:
spark._jsparkSession, optionModule.apply(None)
)

isSupported = config.mosaic_context._context.checkDBR(spark._jsparkSession)
if not isSupported:
print("DEPRECATION WARNING: Mosaic is not supported on the selected Databricks Runtime. \n")
print("DEPRECATION WARNING: Mosaic will stop working on this cluster from version v0.4.0+. \n")
print("Please use a Databricks Photon-enabled Runtime (for performance benefits) or Runtime ML (for spatial AI benefits). \n")


# Not yet added to the pyspark API
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
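
With the check above, enabling Mosaic on an unsupported runtime prints the deprecation warnings but still proceeds; a minimal usage sketch:

    from pyspark.sql import SparkSession
    from mosaic import enable_mosaic

    spark = SparkSession.builder.getOrCreate()
    enable_mosaic(spark, dbutils)   # on an unsupported DBR this prints the warnings and continues
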
22 changes: 21 additions & 1 deletion python/mosaic/api/functions.py
@@ -28,6 +28,7 @@
"st_scale",
"st_rotate",
"st_centroid",
"st_centroid2D",
"st_numpoints",
"st_isvalid",
"st_distance",
@@ -415,6 +416,25 @@ def st_rotate(geom: ColumnOrName, td: ColumnOrName) -> Column:


def st_centroid(geom: ColumnOrName) -> Column:
"""
Returns the POINT geometry representing the centroid of `geom`.
Parameters
----------
geom : Column
Returns
-------
Column (WKT/WKB)
The centroid of `geom` as a POINT geometry.
"""
return config.mosaic_context.invoke_function(
"st_centroid", pyspark_to_java_column(geom)
)


def st_centroid2D(geom: ColumnOrName) -> Column:
"""
Returns the x and y coordinates representing the centroid of `geom`.
@@ -429,7 +449,7 @@ def st_centroid(geom: ColumnOrName) -> Column:
"""
return config.mosaic_context.invoke_function(
"st_centroid", pyspark_to_java_column(geom)
"st_centroid2D", pyspark_to_java_column(geom)
)
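
A hedged sketch contrasting the two bindings above: st_centroid yields a POINT geometry while the legacy st_centroid2D yields an {x, y} struct.

    import mosaic as mos

    df.select(
        mos.st_centroid("wkt"),     # e.g. POINT (25.45 26.97)
        mos.st_centroid2D("wkt"),   # e.g. {25.45, 26.97} -- deprecated
    ).show(truncate=False)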


@@ -31,7 +31,7 @@ class GeoDBFileFormat extends OGRFileFormat with Serializable {
options: Map[String, String],
hadoopConf: Configuration
): PartitionedFile => Iterator[InternalRow] = {
OGRFileFormat.buildReaderImpl(driverName, dataSchema, options)
OGRFileFormat.buildReaderImpl(driverName, dataSchema, requiredSchema, options)
}

}
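
Threading requiredSchema into buildReaderImpl lets the OGR-based readers skip fields a query does not reference; a hedged Python sketch of the effect (the format short name, layer, column names, and path are illustrative):

    df = (
        spark.read.format("geo_db")           # an OGR-backed reader
        .option("layerName", "parcels")
        .load("dbfs:/path/to/data.gdb")
    )
    # only the selected fields (plus layer metadata) are materialised per feature
    df.select("SHAPE", "SHAPE_srid").show()
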
@@ -54,22 +54,11 @@ class OGRFileFormat extends FileFormat with DataSourceRegister with Serializable
options: Map[String, String],
hadoopConf: Configuration
): PartitionedFile => Iterator[InternalRow] = {
val driverName = options.getOrElse("driverName", "")
buildReaderImpl(driverName, dataSchema, options)
}

override def buildReaderWithPartitionValues(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration
): PartitionedFile => Iterator[InternalRow] = {
val driverName = options.getOrElse("driverName", "")
// No column filter at the moment.
// To improve performance, we can filter columns in the OGR layer using requiredSchema.
super.buildReaderWithPartitionValues(sparkSession, dataSchema, partitionSchema, dataSchema, filters, options, hadoopConf)
buildReaderImpl(driverName, dataSchema, requiredSchema, options)
}

override def prepareWrite(
@@ -84,6 +73,11 @@
//noinspection VarCouldBeVal
object OGRFileFormat extends Serializable {

def OGREmptyGeometry: Geometry = {
enableOGRDrivers()
ogr.CreateGeometryFromWkt("POINT EMPTY")
}

/**
* Get the layer from a data source. The method prioritizes the layer name
* over the layer number.
@@ -343,11 +337,19 @@
val geoms = (0 until feature.GetGeomFieldCount())
.map(feature.GetGeomFieldRef)
.flatMap(f => {
f.FlattenTo2D()
Seq(
if (asWKB) f.ExportToWkb else f.ExportToWkt,
Try(f.GetSpatialReference.GetAuthorityCode(null)).getOrElse("0")
)
if (Option(f).isDefined) {
f.FlattenTo2D()
Seq(
if (asWKB) f.ExportToWkb else f.ExportToWkt,
Try(f.GetSpatialReference.GetAuthorityCode(null)).getOrElse("0")
)
} else {
Seq(
if (asWKB) OGREmptyGeometry.ExportToWkb else OGREmptyGeometry.ExportToWkt,
"0"
)
}

})
val values = fields ++ geoms
values.toArray
@@ -431,16 +433,19 @@
*
* @param driverName
* the name of the OGR driver
* @param schema
* the schema of the file
* @param dataSchema
* the full schema of the file
* @param requiredSchema
* the schema of the file that is required for the query
* @param options
* the options to use for the reader
* @return
* a function that can be used to read the file
*/
def buildReaderImpl(
driverName: String,
schema: StructType,
dataSchema: StructType,
requiredSchema: StructType,
options: Map[String, String]
): PartitionedFile => Iterator[InternalRow] = { file: PartitionedFile =>
{
@@ -456,12 +461,16 @@
val layer = dataset.GetLayerByName(resolvedLayerName)
layer.ResetReading()
val metadata = layer.GetMetadata_Dict().toMap
val mask = dataSchema.map(_.name).map(requiredSchema.fieldNames.contains(_)).toArray

var feature: Feature = null
(0 until layer.GetFeatureCount().toInt)
.foldLeft(Seq.empty[InternalRow])((acc, _) => {
feature = layer.GetNextFeature()
val fields = getFeatureFields(feature, schema, asWKB)
val fields = getFeatureFields(feature, dataSchema, asWKB)
.zip(mask)
.filter(_._2)
.map(_._1)
val values = fields ++ Seq(metadata)
val row = Utils.createRow(values)
acc ++ Seq(row)
@@ -32,7 +32,7 @@ class OpenGeoDBFileFormat extends OGRFileFormat with Serializable {
options: Map[String, String],
hadoopConf: Configuration
): PartitionedFile => Iterator[InternalRow] = {
OGRFileFormat.buildReaderImpl(driverName, dataSchema, options)
OGRFileFormat.buildReaderImpl(driverName, dataSchema, requiredSchema, options)
}

}
@@ -38,7 +38,7 @@ class ShapefileFileFormat extends OGRFileFormat with DataSourceRegister with Ser
): PartitionedFile => Iterator[InternalRow] =
(file: PartitionedFile) => {
if (checkExtension(file.filePath)) {
OGRFileFormat.buildReaderImpl(driverName, dataSchema, options)(file)
OGRFileFormat.buildReaderImpl(driverName, dataSchema, requiredSchema, options)(file)
} else {
Seq.empty[InternalRow].iterator
}
@@ -32,6 +32,10 @@ object Utils {
def getCleanPath(path: String, useZipPath: Boolean): String = {
val cleanPath = path.replace("file:/", "/").replace("dbfs:/", "/dbfs/")
if (useZipPath && cleanPath.endsWith(".zip")) {
// It is really important that the resulting path is /vsizip// and not /vsizip/
// /vsizip// is for absolute paths, /vsizip/ is relative to the current working directory
// /vsizip/ won't work on a cluster
// see: https://gdal.org/user/virtual_file_systems.html#vsizip-zip-archives
s"/vsizip/$cleanPath"
} else {
cleanPath
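
A small Python mirror of the path handling above (a sketch only, to show why the absolute /dbfs/... path produces the /vsizip// form):

    def get_clean_path(path: str, use_zip_path: bool) -> str:
        # mirrors Utils.getCleanPath: strip the scheme, then prefix zip archives with /vsizip/;
        # because the cleaned path is absolute, the result starts with /vsizip// rather than /vsizip/
        clean = path.replace("file:/", "/").replace("dbfs:/", "/dbfs/")
        return f"/vsizip/{clean}" if use_zip_path and clean.endswith(".zip") else clean

    print(get_clean_path("dbfs:/tmp/shapes.zip", True))   # /vsizip//dbfs/tmp/shapes.zip
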
@@ -5,7 +5,7 @@ import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, Wi
import com.databricks.labs.mosaic.expressions.geometry.base.UnaryVectorExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._

@@ -52,4 +52,19 @@ object ST_Centroid extends WithExpressionInfo {
GenericExpressionFactory.getBaseBuilder[ST_Centroid](1, expressionConfig)
}

def legacyInfo(database: Option[String], name: String): ExpressionInfo =
new ExpressionInfo(
this.getClass.getCanonicalName,
database.getOrElse(this.database.orNull),
name,
usage,
"",
example,
"",
group,
"1.0",
"",
"built-in"
)

}