Skip to content

Commit

Permalink
Migrate auth token code to cats effect
Browse files Browse the repository at this point in the history
  • Loading branch information
shinyhappydan committed Oct 2, 2023
1 parent eb9fde2 commit e76ce0a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ch.epfl.bluebrain.nexus.delta.wiring

import cats.effect.{Clock, IO}
import ch.epfl.bluebrain.nexus.delta.Main.pluginsMaxPriority
import ch.epfl.bluebrain.nexus.delta.config.AppConfig
import ch.epfl.bluebrain.nexus.delta.kernel.cache.CacheConfig
Expand Down Expand Up @@ -34,8 +35,8 @@ object IdentitiesModule extends ModuleDef {
new OpenIdAuthService(httpClient, realms)
}

make[AuthTokenProvider].fromEffect { (authService: OpenIdAuthService) =>
AuthTokenProvider(authService)
make[AuthTokenProvider].fromEffect { (authService: OpenIdAuthService, clock: Clock[IO]) =>
AuthTokenProvider(authService)(clock)
}

many[RemoteContextResolution].addEffect(ContextValue.fromFile("contexts/identities.json").map { ctx =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import akka.http.scaladsl.model.Uri.Query
import akka.http.scaladsl.model.headers.{`Last-Event-ID`, Accept}
import akka.http.scaladsl.model.{HttpRequest, HttpResponse, StatusCodes}
import akka.stream.alpakka.sse.scaladsl.EventSource
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.plugins.compositeviews.model.CompositeViewSource.RemoteProjectSource
import ch.epfl.bluebrain.nexus.delta.plugins.compositeviews.stream.CompositeBranch
import ch.epfl.bluebrain.nexus.delta.rdf.IriOrBNode.Iri
Expand Down Expand Up @@ -87,11 +88,12 @@ object DeltaClient {
)(implicit
as: ActorSystem[Nothing],
scheduler: Scheduler
) extends DeltaClient {
) extends DeltaClient
with MigrateEffectSyntax {

override def projectStatistics(source: RemoteProjectSource): HttpResult[ProjectStatistics] = {
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
request =
Get(
source.endpoint / "projects" / source.project.organization.value / source.project.project.value / "statistics"
Expand All @@ -104,7 +106,7 @@ object DeltaClient {

override def remaining(source: RemoteProjectSource, offset: Offset): HttpResult[RemainingElems] = {
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
request = Get(elemAddress(source) / "remaining")
.addHeader(accept)
.addHeader(`Last-Event-ID`(offset.value.toString))
Expand All @@ -115,7 +117,7 @@ object DeltaClient {

override def checkElems(source: RemoteProjectSource): HttpResult[Unit] = {
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
result <- client(Head(elemAddress(source)).withCredentials(authToken)) {
case resp if resp.status.isSuccess() => UIO.delay(resp.discardEntityBytes()) >> IO.unit
}
Expand All @@ -130,7 +132,7 @@ object DeltaClient {

def send(request: HttpRequest): Future[HttpResponse] = {
(for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
result <- client[HttpResponse](request.withCredentials(authToken))(IO.pure(_))
} yield result).runToFuture
}
Expand Down Expand Up @@ -164,7 +166,7 @@ object DeltaClient {
val resourceUrl =
source.endpoint / "resources" / source.project.organization.value / source.project.project.value / "_" / id.toString
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
req = Get(
source.resourceTag.fold(resourceUrl)(t => resourceUrl.withQuery(Query("tag" -> t.value)))
).addHeader(Accept(RdfMediaTypes.`application/n-quads`)).withCredentials(authToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import akka.http.scaladsl.model.Multipart.FormData
import akka.http.scaladsl.model.Multipart.FormData.BodyPart
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model.Uri.Path
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.StorageFileRejection.FetchFileRejection.UnexpectedFetchError
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.StorageFileRejection.MoveFileRejection.UnexpectedMoveError
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.StorageFileRejection.{FetchFileRejection, MoveFileRejection, SaveFileRejection}
Expand Down Expand Up @@ -34,7 +35,7 @@ import scala.concurrent.duration._
*/
final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenProvider, credentials: Credentials)(
implicit as: ActorSystem
) {
) extends MigrateEffectSyntax {
import as.dispatcher

private val serviceName = Name.unsafe("remoteStorage")
Expand All @@ -58,7 +59,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
* the storage bucket name
*/
def exists(bucket: Label)(implicit baseUri: BaseUri): IO[HttpClientError, Unit] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value
val req = Head(endpoint).withCredentials(authToken)
client(req) {
Expand All @@ -82,7 +83,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
relativePath: Path,
entity: BodyPartEntity
)(implicit baseUri: BaseUri): IO[SaveFileRejection, RemoteDiskStorageFileAttributes] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "files" / relativePath
val filename = relativePath.lastSegment.getOrElse("filename")
val multipartForm = FormData(BodyPart("file", entity, Map("filename" -> filename))).toEntity()
Expand All @@ -106,7 +107,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
* the relative path to the file location
*/
def getFile(bucket: Label, relativePath: Path)(implicit baseUri: BaseUri): IO[FetchFileRejection, AkkaSource] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "files" / relativePath
client.toDataBytes(Get(endpoint).withCredentials(authToken)).mapError {
case error @ HttpClientStatusError(_, `NotFound`, _) if !bucketNotFoundType(error) =>
Expand All @@ -129,7 +130,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
bucket: Label,
relativePath: Path
)(implicit baseUri: BaseUri): IO[FetchFileRejection, RemoteDiskStorageFileAttributes] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "attributes" / relativePath
client.fromJsonTo[RemoteDiskStorageFileAttributes](Get(endpoint).withCredentials(authToken)).mapError {
case error @ HttpClientStatusError(_, `NotFound`, _) if !bucketNotFoundType(error) =>
Expand All @@ -156,7 +157,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
sourceRelativePath: Path,
destRelativePath: Path
)(implicit baseUri: BaseUri): IO[MoveFileRejection, RemoteDiskStorageFileAttributes] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "files" / destRelativePath
val payload = Json.obj("source" -> sourceRelativePath.toString.asJson)
client.fromJsonTo[RemoteDiskStorageFileAttributes](Put(endpoint, payload).withCredentials(authToken)).mapError {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,38 @@
package ch.epfl.bluebrain.nexus.delta.sdk.auth

import cats.effect.Clock
import cats.effect.{Clock, IO}
import ch.epfl.bluebrain.nexus.delta.kernel.Logger
import ch.epfl.bluebrain.nexus.delta.kernel.cache.KeyValueStore
import ch.epfl.bluebrain.nexus.delta.kernel.cache.LocalCache
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.kernel.utils.IOUtils
import ch.epfl.bluebrain.nexus.delta.kernel.utils.IOInstant
import ch.epfl.bluebrain.nexus.delta.sdk.auth.Credentials.ClientCredentials
import ch.epfl.bluebrain.nexus.delta.sdk.identities.ParsedToken
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.AuthToken
import monix.bio.UIO
import monix.bio

import java.time.{Duration, Instant}

/**
* Provides an auth token for the service account, for use when comunicating with remote storage
*/
trait AuthTokenProvider {
def apply(credentials: Credentials): UIO[Option[AuthToken]]
def apply(credentials: Credentials): IO[Option[AuthToken]]
}

object AuthTokenProvider {
def apply(authService: OpenIdAuthService): UIO[AuthTokenProvider] = {
KeyValueStore[ClientCredentials, ParsedToken]().map(cache => new CachingOpenIdAuthTokenProvider(authService, cache))
def apply(authService: OpenIdAuthService)(implicit clock: Clock[IO]): bio.UIO[AuthTokenProvider] = {
LocalCache[ClientCredentials, ParsedToken]()
.map(cache => new CachingOpenIdAuthTokenProvider(authService, cache))
.toBIO
}
def anonymousForTest: AuthTokenProvider = new AnonymousAuthTokenProvider
def fixedForTest(token: String): AuthTokenProvider = new AuthTokenProvider {
override def apply(credentials: Credentials): UIO[Option[AuthToken]] = UIO.pure(Some(AuthToken(token)))
override def apply(credentials: Credentials): IO[Option[AuthToken]] = IO.pure(Some(AuthToken(token)))
}
}

private class AnonymousAuthTokenProvider extends AuthTokenProvider {
override def apply(credentials: Credentials): UIO[Option[AuthToken]] = UIO.pure(None)
override def apply(credentials: Credentials): IO[Option[AuthToken]] = IO.pure(None)
}

/**
Expand All @@ -39,42 +41,42 @@ private class AnonymousAuthTokenProvider extends AuthTokenProvider {
*/
private class CachingOpenIdAuthTokenProvider(
service: OpenIdAuthService,
cache: KeyValueStore[ClientCredentials, ParsedToken]
cache: LocalCache[ClientCredentials, ParsedToken]
)(implicit
clock: Clock[UIO]
clock: Clock[IO]
) extends AuthTokenProvider
with MigrateEffectSyntax {

private val logger = Logger.cats[CachingOpenIdAuthTokenProvider]

override def apply(credentials: Credentials): UIO[Option[AuthToken]] = {
override def apply(credentials: Credentials): IO[Option[AuthToken]] = {

credentials match {
case Credentials.Anonymous => UIO.pure(None)
case Credentials.JWTToken(token) => UIO.pure(Some(AuthToken(token)))
case Credentials.Anonymous => IO.pure(None)
case Credentials.JWTToken(token) => IO.pure(Some(AuthToken(token)))
case credentials: ClientCredentials => clientCredentialsFlow(credentials)
}
}

private def clientCredentialsFlow(credentials: ClientCredentials) = {
private def clientCredentialsFlow(credentials: ClientCredentials): IO[Some[AuthToken]] = {
for {
existingValue <- cache.get(credentials)
now <- IOUtils.instant
now <- IOInstant.now
finalValue <- existingValue match {
case None =>
logger.info("Fetching auth token, no initial value.").toUIO >>
logger.info("Fetching auth token, no initial value.") *>
fetchValue(credentials)
case Some(value) if isExpired(value, now) =>
logger.info("Fetching new auth token, current value near expiry.").toUIO >>
logger.info("Fetching new auth token, current value near expiry.") *>
fetchValue(credentials)
case Some(value) => UIO.pure(value)
case Some(value) => IO.pure(value)
}
} yield {
Some(AuthToken(finalValue.rawToken))
}
}

private def fetchValue(credentials: ClientCredentials) = {
private def fetchValue(credentials: ClientCredentials): IO[ParsedToken] = {
cache.getOrElseUpdate(credentials, service.auth(credentials))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import akka.http.javadsl.model.headers.HttpCredentials
import akka.http.scaladsl.model.HttpMethods.POST
import akka.http.scaladsl.model.headers.Authorization
import akka.http.scaladsl.model.{HttpRequest, Uri}
import cats.effect.IO
import ch.epfl.bluebrain.nexus.delta.kernel.Secret
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.sdk.auth.Credentials.ClientCredentials
Expand All @@ -15,7 +16,6 @@ import ch.epfl.bluebrain.nexus.delta.sdk.realms.Realms
import ch.epfl.bluebrain.nexus.delta.sdk.realms.model.Realm
import ch.epfl.bluebrain.nexus.delta.sourcing.model.Label
import io.circe.Json
import monix.bio.{IO, UIO}

/**
* Exchanges client credentials for an auth token with a remote OpenId service, as defined in the specified realm
Expand All @@ -25,7 +25,7 @@ class OpenIdAuthService(httpClient: HttpClient, realms: Realms) extends MigrateE
/**
* Exchanges client credentials for an auth token with a remote OpenId service, as defined in the specified realm
*/
def auth(credentials: ClientCredentials): UIO[ParsedToken] = {
def auth(credentials: ClientCredentials): IO[ParsedToken] = {
for {
realm <- findRealm(credentials.realm)
response <- requestToken(realm.tokenEndpoint, credentials.user, credentials.password)
Expand All @@ -35,14 +35,14 @@ class OpenIdAuthService(httpClient: HttpClient, realms: Realms) extends MigrateE
}
}

private def findRealm(id: Label): UIO[Realm] = {
private def findRealm(id: Label): IO[Realm] = {
for {
realm <- realms.fetch(id).toUIO
_ <- UIO.when(realm.deprecated)(UIO.terminate(RealmIsDeprecated(realm.value)))
realm <- realms.fetch(id)
_ <- IO.raiseWhen(realm.deprecated)(RealmIsDeprecated(realm.value))
} yield realm.value
}

private def requestToken(tokenEndpoint: Uri, user: String, password: Secret[String]): UIO[Json] = {
private def requestToken(tokenEndpoint: Uri, user: String, password: Secret[String]): IO[Json] = {
httpClient
.toJson(
HttpRequest(
Expand All @@ -62,13 +62,13 @@ class OpenIdAuthService(httpClient: HttpClient, realms: Realms) extends MigrateE
.hideErrorsWith(AuthTokenHttpError)
}

private def parseResponse(json: Json): UIO[ParsedToken] = {
private def parseResponse(json: Json): IO[ParsedToken] = {
for {
rawToken <- json.hcursor.get[String]("access_token") match {
case Left(failure) => IO.terminate(AuthTokenNotFoundInResponse(failure))
case Right(value) => UIO.pure(value)
case Left(failure) => IO.raiseError(AuthTokenNotFoundInResponse(failure))
case Right(value) => IO.pure(value)
}
parsedToken <- IO.fromEither(ParsedToken.fromToken(AuthToken(rawToken))).hideErrors
parsedToken <- IO.fromEither(ParsedToken.fromToken(AuthToken(rawToken)))
} yield {
parsedToken
}
Expand Down

0 comments on commit e76ce0a

Please sign in to comment.