Skip to content

Commit

Permalink
Loader: Improve management of temporary credentials (close #1205)
Browse files Browse the repository at this point in the history
  • Loading branch information
istreeter authored and pondzix committed Mar 8, 2023
1 parent bcb9ada commit 7f9b2eb
Show file tree
Hide file tree
Showing 15 changed files with 276 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ object Databricks {
override def getLoadStatements(
discovery: DataDiscovery,
eventTableColumns: EventTableColumns,
loadAuthMethod: LoadAuthMethod,
i: Unit
): LoadStatements = {
val toCopy = columnsToCopyFromDiscoveredData(discovery)
val toSkip = ColumnsToSkip(getEntityColumnsPresentInDbOnly(eventTableColumns, toCopy))

NonEmptyList.one(
NonEmptyList.one(loadAuthMethod =>
Statement.EventsCopy(discovery.base, discovery.compression, toCopy, toSkip, discovery.typesInfo, loadAuthMethod, i)
)
}
Expand Down Expand Up @@ -218,7 +217,7 @@ object Databricks {
loadAuthMethod match {
case LoadAuthMethod.NoCreds =>
Fragment.empty
case LoadAuthMethod.TempCreds(awsAccessKey, awsSecretKey, awsSessionToken) =>
case LoadAuthMethod.TempCreds(awsAccessKey, awsSecretKey, awsSessionToken, _) =>
Fragment.const0(
s"WITH ( CREDENTIAL (AWS_ACCESS_KEY = '$awsAccessKey', AWS_SECRET_KEY = '$awsSecretKey', AWS_SESSION_TOKEN = '$awsSessionToken') )"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.snowplowanalytics.snowplow.rdbloader.db.{Statement, Target}
import com.snowplowanalytics.snowplow.rdbloader.cloud.LoadAuthService.LoadAuthMethod
import com.snowplowanalytics.snowplow.rdbloader.ConfigSpec._

import java.time.Instant
import scala.concurrent.duration.DurationInt
import org.specs2.mutable.Specification

Expand Down Expand Up @@ -65,30 +66,33 @@ class DatabricksSpec extends Specification {
)
)

target.getLoadStatements(discovery, eventsColumns, LoadAuthMethod.NoCreds, ()) should be like {
case NonEmptyList(Statement.EventsCopy(path, compression, columnsToCopy, columnsToSkip, _, _, _), Nil) =>
path must beEqualTo(baseFolder)
compression must beEqualTo(Compression.Gzip)

columnsToCopy.names must contain(
allOf(
ColumnName("unstruct_event_com_acme_aaa_1"),
ColumnName("unstruct_event_com_acme_ccc_1"),
ColumnName("contexts_com_acme_yyy_1"),
ColumnName("contexts_com_acme_zzz_1")
)
val results = target
.getLoadStatements(discovery, eventsColumns, ())
.map(f => f(LoadAuthMethod.NoCreds))

results should be like { case NonEmptyList(Statement.EventsCopy(path, compression, columnsToCopy, columnsToSkip, _, _, _), Nil) =>
path must beEqualTo(baseFolder)
compression must beEqualTo(Compression.Gzip)

columnsToCopy.names must contain(
allOf(
ColumnName("unstruct_event_com_acme_aaa_1"),
ColumnName("unstruct_event_com_acme_ccc_1"),
ColumnName("contexts_com_acme_yyy_1"),
ColumnName("contexts_com_acme_zzz_1")
)
)

columnsToCopy.names must not contain (ColumnName("unstruct_event_com_acme_bbb_1"))
columnsToCopy.names must not contain (ColumnName("contexts_com_acme_xxx_1"))
columnsToCopy.names must not contain (ColumnName("not_a_snowplow_column"))
columnsToCopy.names must not contain (ColumnName("unstruct_event_com_acme_bbb_1"))
columnsToCopy.names must not contain (ColumnName("contexts_com_acme_xxx_1"))
columnsToCopy.names must not contain (ColumnName("not_a_snowplow_column"))

columnsToSkip.names must beEqualTo(
List(
ColumnName("unstruct_event_com_acme_bbb_1"),
ColumnName("contexts_com_acme_xxx_1")
)
columnsToSkip.names must beEqualTo(
List(
ColumnName("unstruct_event_com_acme_bbb_1"),
ColumnName("contexts_com_acme_xxx_1")
)
)
}
}
}
Expand Down Expand Up @@ -140,7 +144,7 @@ class DatabricksSpec extends Specification {
ColumnName("contexts_com_acme_yyy_1")
)
)
val loadAuthMethod = LoadAuthMethod.TempCreds("testAccessKey", "testSecretKey", "testSessionToken")
val loadAuthMethod = LoadAuthMethod.TempCreds("testAccessKey", "testSecretKey", "testSessionToken", Instant.now.plusSeconds(3600))
val statement =
Statement.EventsCopy(baseFolder, Compression.Gzip, toCopy, toSkip, TypesInfo.WideRow(PARQUET, List.empty), loadAuthMethod, ())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ object Loader {
F[_]: Transaction[
*[_],
C
]: Concurrent: BlobStorage: Queue.Consumer: Clock: Iglu: Cache: Logging: Timer: Monitoring: ContextShift: LoadAuthService: JsonPathDiscovery,
C[_]: DAO: MonadThrow: Logging,
]: Concurrent: BlobStorage: Queue.Consumer: Clock: Iglu: Cache: Logging: Timer: Monitoring: ContextShift: JsonPathDiscovery,
C[_]: DAO: MonadThrow: Logging: LoadAuthService,
I
](
config: Config[StorageTarget],
Expand Down Expand Up @@ -150,8 +150,8 @@ object Loader {
F[_]: Transaction[
*[_],
C
]: Concurrent: BlobStorage: Queue.Consumer: Iglu: Cache: Logging: Timer: Monitoring: ContextShift: LoadAuthService: JsonPathDiscovery,
C[_]: DAO: MonadThrow: Logging,
]: Concurrent: BlobStorage: Queue.Consumer: Iglu: Cache: Logging: Timer: Monitoring: ContextShift: JsonPathDiscovery,
C[_]: DAO: MonadThrow: Logging: LoadAuthService,
I
](
config: Config[StorageTarget],
Expand All @@ -176,8 +176,8 @@ object Loader {
* actions, instead of whole `Control` object
*/
private def processDiscovery[
F[_]: Transaction[*[_], C]: Concurrent: Iglu: Logging: Timer: Monitoring: ContextShift: LoadAuthService,
C[_]: DAO: MonadThrow: Logging,
F[_]: Transaction[*[_], C]: Concurrent: Iglu: Logging: Timer: Monitoring: ContextShift,
C[_]: DAO: MonadThrow: Logging: LoadAuthService,
I
](
config: Config[StorageTarget],
Expand All @@ -201,8 +201,7 @@ object Loader {
for {
start <- Clock[F].instantNow
_ <- discovery.origin.timestamps.min.map(t => Monitoring[F].periodicMetrics.setEarliestKnownUnloadedData(t)).sequence.void
loadAuth <- LoadAuthService[F].getLoadAuthMethod(config.storage.eventsLoadAuthMethod)
result <- Load.load[F, C, I](config, setStageC, control.incrementAttempts, discovery, loadAuth, initQueryResult, target)
result <- Load.load[F, C, I](config, setStageC, control.incrementAttempts, discovery, initQueryResult, target)
attempts <- control.getAndResetAttempts
_ <- result match {
case Right(ingested) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,28 @@
*/
package com.snowplowanalytics.snowplow.rdbloader.cloud

import cats.{Applicative, ~>}
import cats.effect._
import cats.effect.concurrent.Ref
import cats.implicits._
import com.snowplowanalytics.snowplow.rdbloader.common.cloud.{Utils => CloudUtils}
import com.snowplowanalytics.snowplow.rdbloader.config.StorageTarget
import com.snowplowanalytics.snowplow.rdbloader.config.{Config, StorageTarget}
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.sts.StsAsyncClient
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest

import java.time.Instant
import scala.concurrent.duration.FiniteDuration

trait LoadAuthService[F[_]] {
def getLoadAuthMethod(authMethodConfig: StorageTarget.LoadAuthMethod): F[LoadAuthService.LoadAuthMethod]
trait LoadAuthService[F[_]] { self =>
def forLoadingEvents: F[LoadAuthService.LoadAuthMethod]
def forFolderMonitoring: F[LoadAuthService.LoadAuthMethod]

def mapK[G[_]](arrow: F ~> G): LoadAuthService[G] =
new LoadAuthService[G] {
def forLoadingEvents: G[LoadAuthService.LoadAuthMethod] = arrow(self.forLoadingEvents)
def forFolderMonitoring: G[LoadAuthService.LoadAuthMethod] = arrow(self.forFolderMonitoring)
}
}

object LoadAuthService {
Expand All @@ -48,54 +58,121 @@ object LoadAuthService {
final case class TempCreds(
awsAccessKey: String,
awsSecretKey: String,
awsSessionToken: String
awsSessionToken: String,
expires: Instant
) extends LoadAuthMethod
}

private trait LoadAuthMethodProvider[F[_]] {
def get: F[LoadAuthService.LoadAuthMethod]
}

/**
* Get load auth method according to value specified in the config If temporary credentials method
* is specified in the config, it will get temporary credentials with sending request to STS
* service then return credentials.
*/
def aws[F[_]: Concurrent: ContextShift](region: String, sessionDuration: FiniteDuration): Resource[F, LoadAuthService[F]] =
def aws[F[_]: Concurrent: ContextShift: Clock](
region: String,
timeouts: Config.Timeouts,
eventsLoadAuthMethodConfig: StorageTarget.LoadAuthMethod,
foldersLoadAuthMethodConfig: StorageTarget.LoadAuthMethod
): Resource[F, LoadAuthService[F]] =
(eventsLoadAuthMethodConfig, foldersLoadAuthMethodConfig) match {
case (StorageTarget.LoadAuthMethod.NoCreds, StorageTarget.LoadAuthMethod.NoCreds) =>
noop[F]
case (_, _) =>
for {
stsAsyncClient <- Resource.fromAutoCloseable(
Concurrent[F].delay(
StsAsyncClient
.builder()
.region(Region.of(region))
.build()
)
)
eventsAuthProvider <- Resource.eval(awsCreds(stsAsyncClient, timeouts.loading, eventsLoadAuthMethodConfig))
foldersAuthProvider <- Resource.eval(awsCreds(stsAsyncClient, timeouts.nonLoading, foldersLoadAuthMethodConfig))
} yield new LoadAuthService[F] {
override def forLoadingEvents: F[LoadAuthMethod] =
eventsAuthProvider.get
override def forFolderMonitoring: F[LoadAuthMethod] =
foldersAuthProvider.get
}
}

private def awsCreds[F[_]: Concurrent: ContextShift: Clock](
client: StsAsyncClient,
usageDuration: FiniteDuration,
loadAuthConfig: StorageTarget.LoadAuthMethod
): F[LoadAuthMethodProvider[F]] =
loadAuthConfig match {
case StorageTarget.LoadAuthMethod.NoCreds =>
Concurrent[F].pure {
new LoadAuthMethodProvider[F] {
def get: F[LoadAuthService.LoadAuthMethod] = Concurrent[F].pure(LoadAuthMethod.NoCreds)
}
}
case tc: StorageTarget.LoadAuthMethod.TempCreds =>
awsTempCreds(client, usageDuration, tc)
}

/**
* Either fetches new temporary credentials from STS, or returns cached temporary credentials if
* they are still valid
*
* The new credentials are valid for *twice* the length of time they requested for. This means
* there is a high chance we can re-use the cached credentials later.
*
* @param client
* Used to fetch new credentials
* @param usageDuration
* How long these credentials must be valid for. If cached credentials do not cover this
* duration, then new creds are needed.
* @param tempCredsConfig
* Configuration required for the STS request.
*/
private def awsTempCreds[F[_]: Concurrent: ContextShift: Clock](
client: StsAsyncClient,
usageDuration: FiniteDuration,
tempCredsConfig: StorageTarget.LoadAuthMethod.TempCreds
): F[LoadAuthMethodProvider[F]] =
for {
stsAsyncClient <- Resource.fromAutoCloseable(
Concurrent[F].delay(
StsAsyncClient
.builder()
.region(Region.of(region))
.build()
)
)
authService = new LoadAuthService[F] {
override def getLoadAuthMethod(authMethodConfig: StorageTarget.LoadAuthMethod): F[LoadAuthMethod] =
authMethodConfig match {
case StorageTarget.LoadAuthMethod.NoCreds => Concurrent[F].pure(LoadAuthMethod.NoCreds)
case StorageTarget.LoadAuthMethod.TempCreds(roleArn, roleSessionName) =>
for {
assumeRoleRequest <- Concurrent[F].delay(
AssumeRoleRequest
.builder()
.durationSeconds(sessionDuration.toSeconds.toInt)
.roleArn(roleArn)
.roleSessionName(roleSessionName)
.build()
)
response <- CloudUtils.fromCompletableFuture(
Concurrent[F].delay(stsAsyncClient.assumeRole(assumeRoleRequest))
)
creds = response.credentials()
} yield LoadAuthMethod.TempCreds(creds.accessKeyId(), creds.secretAccessKey(), creds.sessionToken())
}
}
} yield authService
ref <- Ref.of(Option.empty[LoadAuthMethod.TempCreds])
} yield new LoadAuthMethodProvider[F] {
override def get: F[LoadAuthMethod] =
for {
opt <- ref.get
now <- Clock[F].instantNow
next <- opt match {
case Some(tc) if tc.expires.isAfter(now.plusMillis(usageDuration.toMillis)) =>
Concurrent[F].pure(tc)
case _ =>
for {
assumeRoleRequest <- Concurrent[F].delay(
AssumeRoleRequest
.builder()
// 900 is the minimum value accepted by the AssumeRole API
.durationSeconds(Math.max(usageDuration.toSeconds.toInt, 900))
.roleArn(tempCredsConfig.roleArn)
.roleSessionName(tempCredsConfig.roleSessionName)
.build()
)
response <- CloudUtils.fromCompletableFuture(
Concurrent[F].delay(client.assumeRole(assumeRoleRequest))
)
creds = response.credentials()
} yield LoadAuthMethod.TempCreds(creds.accessKeyId, creds.secretAccessKey, creds.sessionToken, creds.expiration)
}
_ <- ref.set(Some(next))
} yield next
}

def noop[F[_]: Concurrent]: Resource[F, LoadAuthService[F]] =
def noop[F[_]: Applicative]: Resource[F, LoadAuthService[F]] =
Resource.pure[F, LoadAuthService[F]](new LoadAuthService[F] {
override def getLoadAuthMethod(authMethodConfig: StorageTarget.LoadAuthMethod): F[LoadAuthMethod] =
authMethodConfig match {
case StorageTarget.LoadAuthMethod.NoCreds => Concurrent[F].pure(LoadAuthMethod.NoCreds)
case _ => Concurrent[F].raiseError(new Exception("No auth service is given to resolve credentials."))
}
override def forLoadingEvents: F[LoadAuthMethod] =
Applicative[F].pure(LoadAuthMethod.NoCreds)
override def forFolderMonitoring: F[LoadAuthMethod] =
Applicative[F].pure(LoadAuthMethod.NoCreds)
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import com.snowplowanalytics.iglu.schemaddl.migrations.{Migration => SchemaMigra
import com.snowplowanalytics.snowplow.rdbloader.LoadStatements
import com.snowplowanalytics.snowplow.rdbloader.db.Columns.EventTableColumns
import com.snowplowanalytics.snowplow.rdbloader.db.Migration.Block
import com.snowplowanalytics.snowplow.rdbloader.cloud.LoadAuthService.LoadAuthMethod
import com.snowplowanalytics.snowplow.rdbloader.discovery.{DataDiscovery, ShreddedType}
import com.snowplowanalytics.snowplow.rdbloader.dsl.DAO
import doobie.Fragment
Expand Down Expand Up @@ -49,7 +48,6 @@ trait Target[I] {
def getLoadStatements(
discovery: DataDiscovery,
eventTableColumns: EventTableColumns,
loadAuthMethod: LoadAuthMethod,
initQueryResult: I
): LoadStatements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Environment[F[_], I](

implicit val daoC: DAO[ConnectionIO] = DAO.connectionIO(target, timeouts)
implicit val loggingC: Logging[ConnectionIO] = logging.mapK(transaction.arrowBack)
implicit val loadAuthServiceC: LoadAuthService[ConnectionIO] = loadAuthService.mapK(transaction.arrowBack)
val controlF: Control[F] = control
val telemetryF: Telemetry[F] = telemetry
val dbTarget: Target[I] = target
Expand Down Expand Up @@ -156,7 +157,9 @@ object Environment {
control.isBusy,
Some(postProcess)
)
loadAuthService <- LoadAuthService.aws[F](c.region.name, config.timeouts.loading)
loadAuthService <-
LoadAuthService
.aws[F](c.region.name, config.timeouts, config.storage.eventsLoadAuthMethod, config.storage.foldersLoadAuthMethod)
jsonPathDiscovery = JsonPathDiscovery.aws[F](c.region.name)
secretStore <- EC2ParameterStore.secretStore[F]
} yield CloudServices(blobStorage, queueConsumer, loadAuthService, jsonPathDiscovery, secretStore)
Expand Down
Loading

0 comments on commit 7f9b2eb

Please sign in to comment.