Skip to content

Commit

Permalink
Remove callback to Delta to check the token, validate it locally (#4340)
Browse files Browse the repository at this point in the history
* Remove callback to Delta to check the token, validate it locally

---------

Co-authored-by: Simon Dumas <[email protected]>
  • Loading branch information
imsdu and Simon Dumas authored Oct 9, 2023
1 parent 3f06fd2 commit f480c08
Show file tree
Hide file tree
Showing 40 changed files with 607 additions and 626 deletions.
8 changes: 5 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ lazy val monixEval = "io.monix" %% "monix-eval"
lazy val munit = "org.scalameta" %% "munit" % munitVersion
lazy val nimbusJoseJwt = "com.nimbusds" % "nimbus-jose-jwt" % nimbusJoseJwtVersion
lazy val pureconfig = "com.github.pureconfig" %% "pureconfig" % pureconfigVersion
lazy val pureconfigCats = "com.github.pureconfig" %% "pureconfig-cats" % pureconfigVersion
lazy val scalaLogging = "com.typesafe.scala-logging" %% "scala-logging" % scalaLoggingVersion
lazy val scalaTest = "org.scalatest" %% "scalatest" % scalaTestVersion
lazy val scalaXml = "org.scala-lang.modules" %% "scala-xml" % scalaXmlVersion
Expand Down Expand Up @@ -207,14 +208,17 @@ lazy val kernel = project
akkaActorTyped, // Needed to create content type
akkaHttpCore,
caffeine,
catsCore,
catsRetry,
circeCore,
circeParser,
handleBars,
monixBio,
nimbusJoseJwt,
kamonCore,
log4cats,
pureconfig,
pureconfigCats,
scalaLogging,
munit % Test,
scalaTest % Test
Expand Down Expand Up @@ -257,7 +261,6 @@ lazy val sourcingPsql = project
.settings(shared, compilation, assertJavaVersion, coverage, release)
.settings(
libraryDependencies ++= Seq(
catsCore,
circeCore,
circeGenericExtras,
circeParser,
Expand Down Expand Up @@ -324,7 +327,6 @@ lazy val sdk = project
distageCore,
fs2,
monixBio,
nimbusJoseJwt,
akkaTestKitTyped % Test,
akkaHttpTestKit % Test,
munit % Test,
Expand Down Expand Up @@ -735,7 +737,7 @@ lazy val storage = project
servicePackaging,
coverageMinimumStmtTotal := 75
)
.dependsOn(kernel)
.dependsOn(kernel, testkit % "test->compile")
.settings(cargo := {
import scala.sys.process._

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ch.epfl.bluebrain.nexus.delta.sdk.identities.model
package ch.epfl.bluebrain.nexus.delta.kernel.jwt

import akka.http.scaladsl.model.headers.OAuth2BearerToken
import io.circe.{Decoder, Encoder}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
package ch.epfl.bluebrain.nexus.delta.sdk.identities
package ch.epfl.bluebrain.nexus.delta.kernel.jwt

import cats.implicits._
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.{AuthToken, TokenRejection}
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.TokenRejection._
import cats.data.NonEmptySet
import cats.syntax.all._
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.TokenRejection._
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.JWKSet
import ch.epfl.bluebrain.nexus.delta.kernel.syntax._
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
import com.nimbusds.jose.proc.{JWSVerificationKeySelector, SecurityContext}
import com.nimbusds.jwt.proc.{DefaultJWTClaimsVerifier, DefaultJWTProcessor}
import com.nimbusds.jwt.{JWTClaimsSet, SignedJWT}

import java.time.Instant
import scala.jdk.CollectionConverters._
import scala.util.Try

/**
Expand All @@ -18,7 +25,24 @@ final case class ParsedToken private (
expirationTime: Instant,
groups: Option[Set[String]],
jwtToken: SignedJWT
)
) {

def validate(audiences: Option[NonEmptySet[String]], keySet: JWKSet): Either[InvalidAccessToken, Unit] = {
val proc = new DefaultJWTProcessor[SecurityContext]
val keySelector = new JWSVerificationKeySelector(JWSAlgorithm.RS256, new ImmutableJWKSet[SecurityContext](keySet))
proc.setJWSKeySelector(keySelector)
audiences.foreach { aud =>
proc.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier(aud.toSet.asJava, null, null, null))
}
Either
.catchNonFatal(proc.process(jwtToken, null))
.bimap(
err => InvalidAccessToken(subject, issuer, err.getMessage),
_ => ()
)
}

}

object ParsedToken {

Expand All @@ -33,13 +57,13 @@ object ParsedToken {
def parseJwt: Either[TokenRejection, SignedJWT] =
Either
.catchNonFatal(SignedJWT.parse(token.value))
.leftMap(_ => InvalidAccessTokenFormat)
.leftMap { e => InvalidAccessTokenFormat(e.getMessage) }

def claims(jwt: SignedJWT): Either[TokenRejection, JWTClaimsSet] =
Either
.catchNonFatal(jwt.getJWTClaimsSet)
.filterOrElse(_ != null, InvalidAccessTokenFormat)
.leftMap(_ => InvalidAccessTokenFormat)
.catchNonFatal(Option(jwt.getJWTClaimsSet))
.leftMap { e => InvalidAccessTokenFormat(e.getMessage) }
.flatMap { _.toRight(InvalidAccessTokenFormat("No claim is defined.")) }

def subject(claimsSet: JWTClaimsSet) = {
val preferredUsername = Try(claimsSet.getStringClaim("preferred_username"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package ch.epfl.bluebrain.nexus.delta.kernel.jwt

/**
* Enumeration of token rejections.
*
* @param reason
* a descriptive message for reasons why a token is rejected by the system
*/
// $COVERAGE-OFF$
sealed abstract class TokenRejection(reason: String) extends Exception with Product with Serializable {
override def fillInStackTrace(): Throwable = this
override def getMessage: String = reason
}

object TokenRejection {

/**
* Rejection for cases where the AccessToken is not a properly formatted signed JWT.
*/
final case class InvalidAccessTokenFormat(details: String)
extends TokenRejection(
s"Access token is invalid; possible causes are: JWT not signed, encoded parts are not properly encoded or each part is not a valid json, details: '$details'"
)

/**
* Rejection for cases where the access token does not contain a subject in the claim set.
*/
final case object AccessTokenDoesNotContainSubject extends TokenRejection("The token doesn't contain a subject.")

/**
* Rejection for cases where the access token does not contain an issuer in the claim set.
*/
final case object AccessTokenDoesNotContainAnIssuer extends TokenRejection("The token doesn't contain an issuer.")

/**
* Rejection for cases where the issuer specified in the access token claim set is unknown; also applies to issuers
* of deprecated realms.
*/
final case object UnknownAccessTokenIssuer extends TokenRejection("The issuer referenced in the token was not found.")

/**
* Rejection for cases where the access token is invalid according to JWTClaimsVerifier
*/
final case class InvalidAccessToken(subject: String, issuer: String, details: String)
extends TokenRejection(s"The provided token is invalid for user '$subject/$issuer' .")

/**
* Rejection for cases where we couldn't fetch the groups from the OIDC provider
*/
final case class GetGroupsFromOidcError(subject: String, issuer: String)
extends TokenRejection(
"The token is invalid; possible causes are: the OIDC provider is unreachable."
)
}
// $COVERAGE-ON$
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ch.epfl.bluebrain.nexus.delta.sdk.syntax
package ch.epfl.bluebrain.nexus.delta.kernel.syntax

import cats.data.NonEmptySet

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package ch.epfl.bluebrain.nexus.delta.kernel

package object syntax extends KamonSyntax with ClassTagSyntax with IOSyntax with InstantSyntax
package object syntax extends KamonSyntax with ClassTagSyntax with IOSyntax with InstantSyntax with NonEmptySetSyntax
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ import cats.effect.{Clock, IO}
import ch.epfl.bluebrain.nexus.delta.kernel.Logger
import ch.epfl.bluebrain.nexus.delta.kernel.cache.LocalCache
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.{AuthToken, ParsedToken}
import ch.epfl.bluebrain.nexus.delta.kernel.utils.IOInstant
import ch.epfl.bluebrain.nexus.delta.sdk.auth.Credentials.ClientCredentials
import ch.epfl.bluebrain.nexus.delta.sdk.identities.ParsedToken
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.AuthToken
import monix.bio

import java.time.{Duration, Instant}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ import akka.http.scaladsl.model.{HttpRequest, Uri}
import cats.effect.IO
import ch.epfl.bluebrain.nexus.delta.kernel.Secret
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.{AuthToken, ParsedToken}
import ch.epfl.bluebrain.nexus.delta.sdk.auth.Credentials.ClientCredentials
import ch.epfl.bluebrain.nexus.delta.sdk.error.AuthTokenError.{AuthTokenHttpError, AuthTokenNotFoundInResponse, RealmIsDeprecated}
import ch.epfl.bluebrain.nexus.delta.sdk.http.HttpClient
import ch.epfl.bluebrain.nexus.delta.sdk.identities.ParsedToken
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.AuthToken
import ch.epfl.bluebrain.nexus.delta.sdk.realms.Realms
import ch.epfl.bluebrain.nexus.delta.sdk.realms.model.Realm
import ch.epfl.bluebrain.nexus.delta.sourcing.model.Label
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import ch.epfl.bluebrain.nexus.delta.sdk.acls.model.AclAddress
import ch.epfl.bluebrain.nexus.delta.sdk.error.IdentityError.{AuthenticationFailed, InvalidToken}
import ch.epfl.bluebrain.nexus.delta.sdk.error.ServiceError.AuthorizationFailed
import ch.epfl.bluebrain.nexus.delta.sdk.identities.Identities
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.{AuthToken, Caller, ServiceAccount, TokenRejection}
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.{Caller, ServiceAccount}
import ch.epfl.bluebrain.nexus.delta.sdk.permissions.model.Permission
import ch.epfl.bluebrain.nexus.delta.sourcing.model.Identity.Subject
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration._
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.{AuthToken, TokenRejection}

import scala.concurrent.Future

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package ch.epfl.bluebrain.nexus.delta.sdk.error

import akka.http.scaladsl.model.StatusCodes
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.TokenRejection
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.TokenRejection.InvalidAccessToken
import ch.epfl.bluebrain.nexus.delta.kernel.utils.ClassUtils
import ch.epfl.bluebrain.nexus.delta.rdf.IriOrBNode.BNode
import ch.epfl.bluebrain.nexus.delta.rdf.Vocabulary.contexts
import ch.epfl.bluebrain.nexus.delta.rdf.jsonld.context.ContextValue
import ch.epfl.bluebrain.nexus.delta.rdf.jsonld.context.JsonLdContext.keywords
import ch.epfl.bluebrain.nexus.delta.rdf.jsonld.encoder.JsonLdEncoder
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.TokenRejection
import io.circe.syntax._
import ch.epfl.bluebrain.nexus.delta.sdk.marshalling.HttpResponseFields
import ch.epfl.bluebrain.nexus.delta.sdk.syntax.httpResponseFieldsSyntax
import io.circe.syntax.EncoderOps
import io.circe.{Encoder, JsonObject}

/**
Expand Down Expand Up @@ -34,6 +40,19 @@ object IdentityError {
*/
final case class InvalidToken(rejection: TokenRejection) extends IdentityError(rejection.getMessage)

implicit val tokenRejectionEncoder: Encoder.AsObject[TokenRejection] =
Encoder.AsObject.instance { r =>
val tpe = ClassUtils.simpleName(r)
val json = JsonObject.empty.add(keywords.tpe, tpe.asJson).add("reason", r.getMessage.asJson)
r match {
case InvalidAccessToken(_, _, error) => json.add("details", error.asJson)
case _ => json
}
}

implicit final val tokenRejectionJsonLdEncoder: JsonLdEncoder[TokenRejection] =
JsonLdEncoder.computeFromCirce(id = BNode.random, ctx = ContextValue(contexts.error))

implicit val identityErrorEncoder: Encoder.AsObject[IdentityError] =
Encoder.AsObject.instance[IdentityError] {
case InvalidToken(r) =>
Expand All @@ -44,4 +63,13 @@ object IdentityError {

implicit val identityErrorJsonLdEncoder: JsonLdEncoder[IdentityError] =
JsonLdEncoder.computeFromCirce(ContextValue(contexts.error))

implicit val responseFieldsTokenRejection: HttpResponseFields[TokenRejection] =
HttpResponseFields(_ => StatusCodes.Unauthorized)

implicit val responseFieldsIdentities: HttpResponseFields[IdentityError] =
HttpResponseFields {
case IdentityError.AuthenticationFailed => StatusCodes.Unauthorized
case IdentityError.InvalidToken(rejection) => rejection.status
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package ch.epfl.bluebrain.nexus.delta.sdk.identities

import cats.effect.IO
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.{AuthToken, Caller}
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.AuthToken
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.Caller

/**
* Operations pertaining to authentication, token validation and identities.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,27 @@ package ch.epfl.bluebrain.nexus.delta.sdk.identities

import akka.http.scaladsl.model.headers.{Authorization, OAuth2BearerToken}
import akka.http.scaladsl.model.{HttpRequest, StatusCodes, Uri}
import cats.data.{NonEmptySet, OptionT}
import cats.data.OptionT
import cats.effect.IO
import cats.syntax.all._
import ch.epfl.bluebrain.nexus.delta.kernel.Logger
import ch.epfl.bluebrain.nexus.delta.kernel.cache.{CacheConfig, LocalCache}
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration._
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.TokenRejection.{GetGroupsFromOidcError, InvalidAccessToken, UnknownAccessTokenIssuer}
import ch.epfl.bluebrain.nexus.delta.kernel.jwt.{AuthToken, ParsedToken}
import ch.epfl.bluebrain.nexus.delta.kernel.kamon.KamonMetricComponent
import ch.epfl.bluebrain.nexus.delta.kernel.search.Pagination.FromPagination
import ch.epfl.bluebrain.nexus.delta.sdk.http.HttpClient
import ch.epfl.bluebrain.nexus.delta.sdk.http.HttpClientError.HttpClientStatusError
import ch.epfl.bluebrain.nexus.delta.sdk.identities.IdentitiesImpl.{extractGroups, logger, GroupsCache, RealmCache}
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.TokenRejection.{GetGroupsFromOidcError, InvalidAccessToken, UnknownAccessTokenIssuer}
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.{AuthToken, Caller}
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.Caller
import ch.epfl.bluebrain.nexus.delta.sdk.model.ResourceF
import ch.epfl.bluebrain.nexus.delta.sdk.model.search.SearchParams.RealmSearchParams
import ch.epfl.bluebrain.nexus.delta.sdk.realms.Realms
import ch.epfl.bluebrain.nexus.delta.sdk.realms.model.Realm
import ch.epfl.bluebrain.nexus.delta.sdk.syntax._
import ch.epfl.bluebrain.nexus.delta.sourcing.model.Identity.{Anonymous, Authenticated, Group, User}
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
import com.nimbusds.jose.jwk.{JWK, JWKSet}
import com.nimbusds.jose.proc.{JWSVerificationKeySelector, SecurityContext}
import com.nimbusds.jwt.proc.{DefaultJWTClaimsVerifier, DefaultJWTProcessor}
import io.circe.{Decoder, HCursor, Json}

import scala.util.Try
Expand All @@ -48,22 +45,8 @@ class IdentitiesImpl private[identities] (
new JWKSet(jwks.toList.asJava)
}

def validate(audiences: Option[NonEmptySet[String]], token: ParsedToken, keySet: JWKSet) = {
val proc = new DefaultJWTProcessor[SecurityContext]
val keySelector = new JWSVerificationKeySelector(JWSAlgorithm.RS256, new ImmutableJWKSet[SecurityContext](keySet))
proc.setJWSKeySelector(keySelector)
audiences.foreach { aud =>
proc.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier(aud.toSet.asJava, null, null, null))
}
IO.fromEither(
Either
.catchNonFatal(proc.process(token.jwtToken, null))
.leftMap(err => InvalidAccessToken(token.subject, token.issuer, err.getMessage))
)
}

def fetchRealm(parsedToken: ParsedToken): IO[Realm] = {
val getRealm = realm.getOrElseAttemptUpdate(parsedToken.rawToken, findActiveRealm(parsedToken.issuer))
val getRealm = realm.getOrElseAttemptUpdate(parsedToken.issuer, findActiveRealm(parsedToken.issuer))
OptionT(getRealm).getOrRaise(UnknownAccessTokenIssuer)
}

Expand All @@ -85,7 +68,7 @@ class IdentitiesImpl private[identities] (
val result = for {
parsedToken <- IO.fromEither(ParsedToken.fromToken(token))
activeRealm <- fetchRealm(parsedToken)
_ <- validate(activeRealm.acceptedAudiences, parsedToken, realmKeyset(activeRealm))
_ <- IO.fromEither(parsedToken.validate(activeRealm.acceptedAudiences, realmKeyset(activeRealm)))
groups <- fetchGroups(parsedToken, activeRealm)
} yield {
val user = User(parsedToken.subject, activeRealm.label)
Expand Down
Loading

0 comments on commit f480c08

Please sign in to comment.