Databricks loader: Support for generated columns (close #951)
istreeter authored and pondzix committed Jun 29, 2022
1 parent bb84f44 commit 07a98f2
Showing 15 changed files with 373 additions and 211 deletions.
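
In short: the loader now reads the actual columns of the events table (via the new GetColumns / SHOW columns statement), copies the entity columns present in the current batch as-is, and selects NULL for entity columns that exist only in the table. The COPY INTO therefore always receives an explicit, complete select list with NULL placeholders, which is the mechanics behind the generated-columns support in this commit. A condensed sketch of the new select-list construction, pieced together from the Databricks target diff below (not the exact code; column names are illustrative):

import com.snowplowanalytics.snowplow.rdbloader.db.Columns.{ColumnName, ColumnsToCopy, ColumnsToSkip}

// Columns discovered in the batch vs. entity columns present only in the warehouse table.
val toCopy = ColumnsToCopy(List(ColumnName("app_id"), ColumnName("unstruct_event_com_acme_aaa_1")))
val toSkip = ColumnsToSkip(List(ColumnName("contexts_com_acme_yyy_1")))

val nonNulls = toCopy.names.map(_.value)
val nulls    = toSkip.names.map(c => s"NULL AS ${c.value}")
val allColumns = (nonNulls ::: nulls) :+ "current_timestamp() AS load_tstamp"

// "app_id,unstruct_event_com_acme_aaa_1,NULL AS contexts_com_acme_yyy_1,current_timestamp() AS load_tstamp"
val selectList = allColumns.mkString(",")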
@@ -12,58 +12,53 @@
*/
package com.snowplowanalytics.snowplow.loader.databricks

import java.sql.Timestamp
import cats.data.NonEmptyList
import io.circe.syntax._
import doobie.Fragment
import doobie.implicits.javasql._
import doobie.implicits._
import com.snowplowanalytics.iglu.core.SchemaKey
import com.snowplowanalytics.iglu.schemaddl.migrations.{Migration, SchemaList}
import com.snowplowanalytics.snowplow.rdbloader.LoadStatements
import com.snowplowanalytics.snowplow.rdbloader.config.{Config, StorageTarget}
import com.snowplowanalytics.snowplow.rdbloader.db.Columns.{ColumnsToCopy, ColumnsToSkip, EventTableColumns}
import com.snowplowanalytics.snowplow.rdbloader.db.Migration.{Block, Entity}
import com.snowplowanalytics.snowplow.rdbloader.db.{Statement, Target, AtomicColumns}
import com.snowplowanalytics.snowplow.rdbloader.db.{Manifest, Statement, Target}
import com.snowplowanalytics.snowplow.rdbloader.discovery.{DataDiscovery, ShreddedType}
import com.snowplowanalytics.snowplow.rdbloader.loading.EventsTable
import com.snowplowanalytics.snowplow.analytics.scalasdk.SnowplowEvent
import doobie.Fragment
import doobie.implicits._
import doobie.implicits.javasql._
import io.circe.syntax._

import java.sql.Timestamp

object Databricks {

val AlertingTempTableName = "rdb_folder_monitoring"
val ManifestName = "manifest"
val UnstructPrefix = "unstruct_event_"
val ContextsPrefix = "contexts_"

def build(config: Config[StorageTarget]): Either[String, Target] = {
config.storage match {
case tgt: StorageTarget.Databricks =>
val result = new Target {
def updateTable(migration: Migration): Block =
Block(Nil, Nil, Entity.Table(tgt.schema, SchemaKey(migration.vendor, migration.name, "jsonschema", migration.to)))

def extendTable(info: ShreddedType.Info): Option[Block] = None

override val requiresEventsColumns: Boolean = true

def getLoadStatements(discovery: DataDiscovery): LoadStatements =
NonEmptyList.one(Statement.EventsCopy(discovery.base, discovery.compression, getColumns(discovery)))
override def updateTable(migration: Migration): Block =
Block(Nil, Nil, Entity.Table(tgt.schema, SchemaKey(migration.vendor, migration.name, "jsonschema", migration.to)))

def getColumns(discovery: DataDiscovery): List[String] = {
val atomicColumns = AtomicColumns.Columns
val shredTypeColumns = discovery.shreddedTypes
.filterNot(_.isAtomic)
.map(getShredTypeColumn)
atomicColumns ::: shredTypeColumns
}
override def extendTable(info: ShreddedType.Info): Option[Block] = None

def getShredTypeColumn(shreddedType: ShreddedType): String = {
val shredProperty = shreddedType.getSnowplowEntity.toSdkProperty
val info = shreddedType.info
SnowplowEvent.transformSchema(shredProperty, info.vendor, info.name, info.model)
override def getLoadStatements(discovery: DataDiscovery, eventTableColumns: EventTableColumns): LoadStatements = {
val toCopy = ColumnsToCopy.fromDiscoveredData(discovery)
val toSkip = ColumnsToSkip(getEntityColumnsPresentInDbOnly(eventTableColumns, toCopy))

NonEmptyList.one(Statement.EventsCopy(discovery.base, discovery.compression, toCopy, toSkip))
}

def createTable(schemas: SchemaList): Block = Block(Nil, Nil, Entity.Table(tgt.schema, schemas.latest.schemaKey))
override def createTable(schemas: SchemaList): Block = Block(Nil, Nil, Entity.Table(tgt.schema, schemas.latest.schemaKey))

def getManifest: Statement =
override def getManifest: Statement =
Statement.CreateTable(
Fragment.const0(s"""CREATE TABLE IF NOT EXISTS ${qualify(ManifestName)} (
Fragment.const0(s"""CREATE TABLE IF NOT EXISTS ${qualify(Manifest.Name)} (
| base VARCHAR(512) NOT NULL,
| types VARCHAR(65535) NOT NULL,
| shredding_started TIMESTAMP NOT NULL,
@@ -79,7 +74,7 @@ object Databricks {
|""".stripMargin)
)

def toFragment(statement: Statement): Fragment =
override def toFragment(statement: Statement): Fragment =
statement match {
case Statement.Select1 => sql"SELECT 1"
case Statement.ReadyCheck => sql"SELECT 1"
@@ -93,18 +88,23 @@
sql"DROP TABLE IF EXISTS $frTableName"
case Statement.FoldersMinusManifest =>
val frTableName = Fragment.const(qualify(AlertingTempTableName))
val frManifest = Fragment.const(qualify(ManifestName))
val frManifest = Fragment.const(qualify(Manifest.Name))
sql"SELECT run_id FROM $frTableName MINUS SELECT base FROM $frManifest"
case Statement.FoldersCopy(source) =>
val frTableName = Fragment.const(qualify(AlertingTempTableName))
val frPath = Fragment.const0(source)
sql"""COPY INTO $frTableName
FROM (SELECT _C0::VARCHAR(512) RUN_ID FROM '$frPath')
FILEFORMAT = CSV""";
case Statement.EventsCopy(path, _, columns) =>
val frTableName = Fragment.const(qualify(EventsTable.MainName))
val frPath = Fragment.const0(s"$path/output=good")
val frSelectColumns = Fragment.const0(columns.mkString(",") + ", current_timestamp() as load_tstamp")
case Statement.EventsCopy(path, _, toCopy, toSkip) =>
val frTableName = Fragment.const(qualify(EventsTable.MainName))
val frPath = Fragment.const0(path.append("output=good"))
val nonNulls = toCopy.names.map(_.value)
val nulls = toSkip.names.map(c => s"NULL AS ${c.value}")
val currentTimestamp = "current_timestamp() AS load_tstamp"
val allColumns = (nonNulls ::: nulls) :+ currentTimestamp

val frSelectColumns = Fragment.const0(allColumns.mkString(","))
sql"""COPY INTO $frTableName
FROM (
SELECT $frSelectColumns from '$frPath'
@@ -123,12 +123,11 @@
throw new IllegalStateException("Databricks Loader does not support migrations")
case _: Statement.RenameTable =>
throw new IllegalStateException("Databricks Loader does not support migrations")
case Statement.SetSearchPath =>
throw new IllegalStateException("Databricks Loader does not support migrations")
case _: Statement.GetColumns =>
throw new IllegalStateException("Databricks Loader does not support migrations")
case Statement.GetColumns(tableName) =>
val qualifiedName = Fragment.const(qualify(tableName))
sql"SHOW columns in $qualifiedName"
case Statement.ManifestAdd(message) =>
val tableName = Fragment.const(qualify(ManifestName))
val tableName = Fragment.const(qualify(Manifest.Name))
val types = message.types.asJson.noSpaces
sql"""INSERT INTO $tableName
(base, types, shredding_started, shredding_completed,
@@ -148,7 +147,7 @@
base, types, shredding_started, shredding_completed,
min_collector_tstamp, max_collector_tstamp,
compression, processor_artifact, processor_version, count_good
FROM ${Fragment.const0(qualify(ManifestName))} WHERE base = $base"""
FROM ${Fragment.const0(qualify(Manifest.Name))} WHERE base = $base"""
case Statement.AddLoadTstampColumn =>
throw new IllegalStateException("Databricks Loader does not support load_tstamp column")
case Statement.CreateTable(ddl) =>
@@ -162,14 +161,18 @@
throw new IllegalStateException("Databricks Loader does not support migrations")
}

def qualify(tableName: String): String =
private def qualify(tableName: String): String =
s"${tgt.catalog}.${tgt.schema}.$tableName"
}
Right(result)
case other =>
Left(s"Invalid State: trying to build Databricks interpreter with unrecognized config (${other.driver} driver)")
}

}

private def getEntityColumnsPresentInDbOnly(eventTableColumns: EventTableColumns, toCopy: ColumnsToCopy) = {
eventTableColumns
.filter(name => name.value.startsWith(UnstructPrefix) || name.value.startsWith(ContextsPrefix))
.diff(toCopy.names)
}
}
@@ -0,0 +1,124 @@
/*
* Copyright (c) 2012-2022 Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Apache License Version 2.0,
* and you may not use this file except in compliance with the Apache License Version 2.0.
* You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the Apache License Version 2.0 is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the Apache License Version 2.0 for the specific language governing permissions and limitations there under.
*/
package com.snowplowanalytics.snowplow.loader.databricks

import cats.data.NonEmptyList
import com.snowplowanalytics.snowplow.rdbloader.common.S3
import com.snowplowanalytics.snowplow.rdbloader.discovery.{DataDiscovery, ShreddedType}
import com.snowplowanalytics.snowplow.rdbloader.common.config.TransformerConfig.Compression
import com.snowplowanalytics.snowplow.rdbloader.common.config.Region
import com.snowplowanalytics.snowplow.rdbloader.common.LoaderMessage.SnowplowEntity
import com.snowplowanalytics.snowplow.rdbloader.config.{Config, StorageTarget}
import com.snowplowanalytics.snowplow.rdbloader.db.Columns.{ColumnName, ColumnsToCopy, ColumnsToSkip}
import com.snowplowanalytics.snowplow.rdbloader.db.{Statement, Target}

import scala.concurrent.duration.DurationInt
import org.specs2.mutable.Specification


class DatabricksSpec extends Specification {
import DatabricksSpec._

"getLoadStatements" should {

"create LoadStatements with columns to copy and columns to skip" in {

val eventsColumns = List(
"unstruct_event_com_acme_aaa_1",
"unstruct_event_com_acme_bbb_1",
"contexts_com_acme_xxx_1",
"contexts_com_acme_yyy_1",
"not_a_snowplow_column"
).map(ColumnName)

val shreddedTypes = List(
ShreddedType.Widerow(ShreddedType.Info(baseFolder, "com_acme", "aaa", 1, SnowplowEntity.SelfDescribingEvent)),
ShreddedType.Widerow(ShreddedType.Info(baseFolder, "com_acme", "ccc", 1, SnowplowEntity.SelfDescribingEvent)),
ShreddedType.Widerow(ShreddedType.Info(baseFolder, "com_acme", "yyy", 1, SnowplowEntity.Context)),
ShreddedType.Widerow(ShreddedType.Info(baseFolder, "com_acme", "zzz", 1, SnowplowEntity.Context))
)

val discovery = DataDiscovery(baseFolder, shreddedTypes, Compression.Gzip)

target.getLoadStatements(discovery, eventsColumns) should be like {
case NonEmptyList(Statement.EventsCopy(path, compression, columnsToCopy, columnsToSkip), Nil) =>
path must beEqualTo(baseFolder)
compression must beEqualTo(Compression.Gzip)

columnsToCopy.names must contain(allOf(
ColumnName("unstruct_event_com_acme_aaa_1"),
ColumnName("unstruct_event_com_acme_ccc_1"),
ColumnName("contexts_com_acme_yyy_1"),
ColumnName("contexts_com_acme_zzz_1"),
))

columnsToCopy.names must not contain(ColumnName("unstruct_event_com_acme_bbb_1"))
columnsToCopy.names must not contain(ColumnName("contexts_com_acme_xxx_1"))
columnsToCopy.names must not contain(ColumnName("not_a_snowplow_column"))

columnsToSkip.names must beEqualTo(List(
ColumnName("unstruct_event_com_acme_bbb_1"),
ColumnName("contexts_com_acme_xxx_1"),
))
}
}
}

"toFragment" should {
"create sql for loading" in {
val toCopy = ColumnsToCopy(List(
ColumnName("app_id"),
ColumnName("unstruct_event_com_acme_aaa_1"),
ColumnName("contexts_com_acme_xxx_1")
))
val toSkip = ColumnsToSkip(List(
ColumnName("unstruct_event_com_acme_bbb_1"),
ColumnName("contexts_com_acme_yyy_1"),
))
val statement = Statement.EventsCopy(baseFolder, Compression.Gzip, toCopy, toSkip)

target.toFragment(statement).toString must beLike { case sql =>
sql must contain("SELECT app_id,unstruct_event_com_acme_aaa_1,contexts_com_acme_xxx_1,NULL AS unstruct_event_com_acme_bbb_1,NULL AS contexts_com_acme_yyy_1,current_timestamp() AS load_tstamp from 's3://somewhere/path/output=good/'")
}
}
}
}

object DatabricksSpec {

val baseFolder: S3.Folder =
S3.Folder.coerce("s3://somewhere/path")

val target: Target = Databricks.build(Config(
Region("eu-central-1"),
None,
Config.Monitoring(None, None, Config.Metrics(None, None, 1.minute), None, None, None),
"my-queue.fifo",
None,
StorageTarget.Databricks(
"host",
"hive_metastore",
"snowplow",
443,
"some/path",
StorageTarget.PasswordConfig.PlainText("xxx"),
None,
"useragent"
),
Config.Schedules(Nil),
Config.Timeouts(1.minute, 1.minute, 1.minute),
Config.Retries(Config.Strategy.Constant, None, 1.minute, None),
Config.Retries(Config.Strategy.Constant, None, 1.minute, None)
)).right.get

}
@@ -18,6 +18,7 @@ import cats.implicits._
import cats.effect.{Clock, Concurrent, MonadThrow, Timer}
import fs2.Stream
import com.snowplowanalytics.snowplow.rdbloader.config.{Config, StorageTarget}
import com.snowplowanalytics.snowplow.rdbloader.db.Columns._
import com.snowplowanalytics.snowplow.rdbloader.db.{AtomicColumns, HealthCheck, Manifest, Statement, Control => DbControl}
import com.snowplowanalytics.snowplow.rdbloader.discovery.{DataDiscovery, NoOperation, Retries}
import com.snowplowanalytics.snowplow.rdbloader.dsl.{AWS, Cache, DAO, FolderMonitoring, Iglu, Logging, Monitoring, StateMonitoring, Transaction}
@@ -87,8 +88,8 @@ object Loader {
* The primary loading process, pulling information from discovery streams
* (SQS and retry queue) and performing the load operation itself
*/
def loadStream[F[_]: Transaction[*[_], C]: Concurrent: AWS: Iglu: Cache: Logging: Timer: Monitoring,
C[_]: DAO: MonadThrow: Logging](config: Config[StorageTarget], control: Control[F]): Stream[F, Unit] = {
private def loadStream[F[_]: Transaction[*[_], C]: Concurrent: AWS: Iglu: Cache: Logging: Timer: Monitoring,
C[_]: DAO: MonadThrow: Logging](config: Config[StorageTarget], control: Control[F]): Stream[F, Unit] = {
val sqsDiscovery: DiscoveryStream[F] =
DataDiscovery.discover[F](config, control.incrementMessages, control.isBusy)
val retryDiscovery: DiscoveryStream[F] =
@@ -105,9 +106,9 @@
* over to `Load`. The primary function handling the global state - everything
* downstream has access only to `F` actions, instead of the whole `Control` object
*/
def processDiscovery[F[_]: Transaction[*[_], C]: Concurrent: Iglu: Logging: Timer: Monitoring,
C[_]: DAO: MonadThrow: Logging](config: Config[StorageTarget], control: Control[F])
(discovery: DataDiscovery.WithOrigin): F[Unit] = {
private def processDiscovery[F[_]: Transaction[*[_], C]: Concurrent: Iglu: Logging: Timer: Monitoring,
C[_]: DAO: MonadThrow: Logging](config: Config[StorageTarget], control: Control[F])
(discovery: DataDiscovery.WithOrigin): F[Unit] = {
val folder = discovery.origin.base
val busy = (control.makeBusy: MakeBusy[F]).apply(folder)
val backgroundCheck: F[Unit] => F[Unit] =
@@ -142,22 +143,28 @@
loading.handleErrorWith(reportLoadFailure[F](discovery, addFailure))
}

def addLoadTstampColumn[F[_]: DAO: Monad: Logging](targetConfig: StorageTarget): F[Unit] =
private def addLoadTstampColumn[F[_]: DAO: Monad: Logging](targetConfig: StorageTarget): F[Unit] =
targetConfig match {
// Adding the load_tstamp column explicitly is not needed due to the merge schema
// feature of Databricks, which creates the missing column itself.
case _: StorageTarget.Databricks => Monad[F].unit
case _ =>
for {
columns <- DbControl.getColumns[F](EventsTable.MainName)
_ <- if (columns.map(_.toLowerCase).contains(AtomicColumns.ColumnsWithDefault.LoadTstamp))
allColumns <- DbControl.getColumns[F](EventsTable.MainName)
_ <- if (loadTstampColumnExist(allColumns))
Logging[F].info("load_tstamp column already exists")
else
DAO[F].executeUpdate(Statement.AddLoadTstampColumn, DAO.Purpose.NonLoading).void *>
Logging[F].info("load_tstamp column is added successfully")
} yield ()
}

private def loadTstampColumnExist(eventTableColumns: EventTableColumns) = {
eventTableColumns
.map(_.value.toLowerCase)
.contains(AtomicColumns.ColumnsWithDefault.LoadTstamp.value)
}

/**
* Handle a failure during loading.
* `Load.getTransaction` can fail only in one "expected" way - if the folder is already loaded
@@ -168,9 +175,9 @@
* @param discovery the original discovery
* @param error the actual error, typically `SQLException`
*/
def reportLoadFailure[F[_]: Logging: Monitoring: Monad](discovery: DataDiscovery.WithOrigin,
addFailure: Throwable => F[Boolean])
(error: Throwable): F[Unit] = {
private def reportLoadFailure[F[_]: Logging: Monitoring: Monad](discovery: DataDiscovery.WithOrigin,
addFailure: Throwable => F[Boolean])
(error: Throwable): F[Unit] = {
val message = error match {
case e: SQLException => s"${error.getMessage} - SqlState: ${e.getSQLState}"
case _ => Option(error.getMessage).getOrElse(error.toString)
@@ -184,7 +191,7 @@
}

/** Last level of failure handling, called when a non-loading stream fails. Called on an application crash */
def reportFatal[F[_]: Apply: Logging: Monitoring]: PartialFunction[Throwable, F[Unit]] = {
private def reportFatal[F[_]: Apply: Logging: Monitoring]: PartialFunction[Throwable, F[Unit]] = {
case error =>
Logging[F].error("Loader shutting down") *>
Monitoring[F].alert(Monitoring.AlertPayload.error(error.toString)) *>
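
A note on the Databricks special case in addLoadTstampColumn above: the explicit ALTER TABLE is skipped because Databricks can merge newly appearing columns into the target table during the COPY itself, so load_tstamp is created by the load. Below is a rough, hypothetical sketch of what a schema-merging COPY INTO looks like on Databricks; the table name, path, select list and file format are placeholders, and the loader's real FILEFORMAT / COPY_OPTIONS sit in the collapsed parts of the diff above, so treat this as an assumption rather than the loader's exact SQL:

import doobie.Fragment

// Hypothetical illustration only: a Databricks COPY INTO with the 'mergeSchema' copy option,
// which lets the warehouse add columns (such as load_tstamp) that are missing from the table.
def copyWithMergeSchema(table: String, path: String, selectList: String): Fragment =
  Fragment.const0(
    s"""COPY INTO $table
       |FROM (SELECT $selectList, current_timestamp() AS load_tstamp FROM '$path')
       |FILEFORMAT = PARQUET
       |COPY_OPTIONS ('mergeSchema' = 'true')""".stripMargin
  )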