Skip to content

Commit

Permalink
Add validation options
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldijou committed Dec 27, 2015
1 parent 4a9392f commit e399d82
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 105 deletions.
282 changes: 196 additions & 86 deletions core/common/src/main/scala/Jwt.scala

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions core/common/src/main/scala/JwtOptions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package pdi.jwt

case class JwtOptions(
signature: Boolean = true,
expiration: Boolean = true,
notBefore: Boolean = true
)

object JwtOptions {
val DEFAULT = new JwtOptions()
}
60 changes: 60 additions & 0 deletions core/common/src/test/scala/JwtSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -218,5 +218,65 @@ class JwtSpec extends UnitSpec with Fixture {
assert(Jwt.decode(token).isFailure)
intercept[JwtNonSupportedAlgorithm] { Jwt.decode(token).get }
}

it("should skip expiration validation depending on options") {
val mock = mockAfterExpiration
val options = JwtOptions(expiration = false)

data foreach { d =>
Jwt.validate(d.token, secretKey, JwtAlgorithm.allHmac, options)
assertResult(true, d.algo.fullName) { Jwt.isValid(d.token, secretKey, JwtAlgorithm.allHmac, options) }
Jwt.validate(d.token, secretKeyOf(d.algo), options)
assertResult(true, d.algo.fullName) { Jwt.isValid(d.token, secretKeyOf(d.algo), options) }
}

dataRSA foreach { d =>
Jwt.validate(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options)
assertResult(true, d.algo.fullName) { Jwt.isValid(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options) }
}

mock.tearDown
}

it("should skip notBefore validation depending on options") {
val mock = mockBeforeNotBefore
val options = JwtOptions(notBefore = false)

data foreach { d =>
val claimNotBefore = claimClass.copy(notBefore = Option(notBefore))
val token = Jwt.encode(claimNotBefore, secretKey, d.algo)

Jwt.validate(token, secretKey, JwtAlgorithm.allHmac, options)
assertResult(true, d.algo.fullName) { Jwt.isValid(token, secretKey, JwtAlgorithm.allHmac, options) }
Jwt.validate(token, secretKeyOf(d.algo), options)
assertResult(true, d.algo.fullName) { Jwt.isValid(token, secretKeyOf(d.algo), options) }
}

dataRSA foreach { d =>
val claimNotBefore = claimClass.copy(notBefore = Option(notBefore))
val token = Jwt.encode(claimNotBefore, privateKeyRSA, d.algo)

Jwt.validate(token, publicKeyRSA, JwtAlgorithm.allRSA, options)
assertResult(true, d.algo.fullName) { Jwt.isValid(token, publicKeyRSA, JwtAlgorithm.allRSA, options) }
}

mock.tearDown
}

it("should skip signature validation depending on options") {
val mock = mockValidTime
val options = JwtOptions(signature = false)

data foreach { d =>
Jwt.validate(d.token, "wrong key", JwtAlgorithm.allHmac, options)
assertResult(true, d.algo.fullName) { Jwt.isValid(d.token, "wrong key", JwtAlgorithm.allHmac, options) }
}

dataRSA foreach { d =>
assertResult(true, d.algo.fullName) { Jwt.isValid(d.token, "wrong key", JwtAlgorithm.allRSA, options) }
}

mock.tearDown
}
}
}
22 changes: 20 additions & 2 deletions docs/src/main/tut/jwt-core-jwt.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
### Basic usage

```tut
import pdi.jwt.{Jwt, JwtAlgorithm, JwtHeader, JwtClaim}
import pdi.jwt.{Jwt, JwtAlgorithm, JwtHeader, JwtClaim, JwtOptions}
val token = Jwt.encode("""{"user":1}""", "secretKey", JwtAlgorithm.HS256)
Jwt.decodeRawAll(token, "secretKey", Seq(JwtAlgorithm.HS256))
Jwt.decodeRawAll(token, "wrongKey", Seq(JwtAlgorithm.HS256))
Expand Down Expand Up @@ -33,7 +33,9 @@ Jwt.encode(JwtHeader(JwtAlgorithm.HS1), JwtClaim("""{"user":1}"""), "key")

### Decoding

In JWT Scala, espcially when using raw strings which are not typesafe at all, there are a lot of possible errors. This is why nearly all `decode` functions will return a `Try` rather than directly the expected result. In case of failure, the wrapped exception should tell you what went wront.
In JWT Scala, espcially when using raw strings which are not typesafe at all, there are a lot of possible errors. This is why nearly all `decode` functions will return a `Try` rather than directly the expected result. In case of failure, the wrapped exception should tell you what went wrong.

Take note that nearly all decoding methods (including those from helper libs) support either a String key, or a PrivateKey with a Hmac algorithm or a PublicKey with a RSA or ECDSA algorithm.

```tut
// Decode all parts of the token as string
Expand Down Expand Up @@ -93,3 +95,19 @@ Jwt.isValid(Jwt.encode(JwtClaim().startsIn(5)))
Jwt.validate("a.b.c")
Jwt.isValid("a.b.c")
```

### Options

All validating and decoding methods support a final optional argument as a `JwtOptions` which allow you to disable validation checks. This is useful if you need to access data from an expired token for example. You can disable `expiration`, `notBefore` and `signature` checks. Be warned that if you disable the last one, you have no guarantee that the user didn't change the content of the token.

```tut:nofail
val expiredToken = Jwt.encode(JwtClaim().by("me").expiresIn(-1));
// Fail since the token is expired
Jwt.isValid(expiredToken)
Jwt.decode(expiredToken)
// Let's disable expiration check
Jwt.isValid(expiredToken, JwtOptions(expiration = false))
Jwt.decode(expiredToken, JwtOptions(expiration = false))
```
74 changes: 60 additions & 14 deletions json/common/src/main/scala/JwtJsonCommon.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,41 +44,87 @@ trait JwtJsonCommon[J] extends JwtCore[JwtHeader, JwtClaim] {
def encode(claim: J, key: PrivateKey, algorithm: JwtAsymetricAlgorithm): String =
encode(stringify(claim), key, algorithm)

def decodeJsonAll(token: String, options: JwtOptions): Try[(J, J, String)] =
decodeRawAll(token, options).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }

def decodeJsonAll(token: String): Try[(J, J, String)] =
decodeRawAll(token).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }
decodeJsonAll(token, JwtOptions.DEFAULT)

def decodeJsonAll(token: String, key: String, algorithms: Seq[JwtHmacAlgorithm], options: JwtOptions): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms, options).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }

def decodeJsonAll(token: String, key: String, algorithms: Seq[JwtHmacAlgorithm]): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }
decodeJsonAll(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJsonAll(token: String, key: String, algorithms: => Seq[JwtAsymetricAlgorithm], options: JwtOptions): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms, options).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }

def decodeJsonAll(token: String, key: String, algorithms: => Seq[JwtAsymetricAlgorithm]): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }
decodeJsonAll(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJsonAll(token: String, key: SecretKey, algorithms: Seq[JwtHmacAlgorithm], options: JwtOptions): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms, options).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }

def decodeJsonAll(token: String, key: SecretKey, algorithms: Seq[JwtHmacAlgorithm]): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }
decodeJsonAll(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJsonAll(token: String, key: SecretKey): Try[(J, J, String)] = decodeJsonAll(token, key, JwtAlgorithm.allHmac)
def decodeJsonAll(token: String, key: SecretKey, options: JwtOptions): Try[(J, J, String)] =
decodeJsonAll(token, key, JwtAlgorithm.allHmac, options)

def decodeJsonAll(token: String, key: SecretKey): Try[(J, J, String)] =
decodeJsonAll(token, key, JwtOptions.DEFAULT)

def decodeJsonAll(token: String, key: PublicKey, algorithms: Seq[JwtAsymetricAlgorithm], options: JwtOptions): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms, options).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }

def decodeJsonAll(token: String, key: PublicKey, algorithms: Seq[JwtAsymetricAlgorithm]): Try[(J, J, String)] =
decodeRawAll(token, key, algorithms).map { tuple => (parse(tuple._1), parse(tuple._2), tuple._3) }
decodeJsonAll(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJsonAll(token: String, key: PublicKey, options: JwtOptions): Try[(J, J, String)] =
decodeJsonAll(token, key, JwtAlgorithm.allAsymetric, options)

def decodeJsonAll(token: String, key: PublicKey): Try[(J, J, String)] = decodeJsonAll(token, key, JwtAlgorithm.allAsymetric)
def decodeJsonAll(token: String, key: PublicKey): Try[(J, J, String)] =
decodeJsonAll(token, key, JwtOptions.DEFAULT)

def decodeJson(token: String, options: JwtOptions): Try[J] =
decodeJsonAll(token, options).map(_._2)

def decodeJson(token: String): Try[J] =
decodeJsonAll(token).map(_._2)
decodeJson(token, JwtOptions.DEFAULT)

def decodeJson(token: String, key: String, algorithms: Seq[JwtHmacAlgorithm], options: JwtOptions): Try[J] =
decodeJsonAll(token, key, algorithms, options).map(_._2)

def decodeJson(token: String, key: String, algorithms: Seq[JwtHmacAlgorithm]): Try[J] =
decodeJsonAll(token, key, algorithms).map(_._2)
decodeJson(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJson(token: String, key: String, algorithms: => Seq[JwtAsymetricAlgorithm], options: JwtOptions): Try[J] =
decodeJsonAll(token, key, algorithms, options).map(_._2)

def decodeJson(token: String, key: String, algorithms: => Seq[JwtAsymetricAlgorithm]): Try[J] =
decodeJsonAll(token, key, algorithms).map(_._2)
decodeJson(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJson(token: String, key: SecretKey, algorithms: Seq[JwtHmacAlgorithm], options: JwtOptions): Try[J] =
decodeJsonAll(token, key, algorithms, options).map(_._2)

def decodeJson(token: String, key: SecretKey, algorithms: Seq[JwtHmacAlgorithm]): Try[J] =
decodeJsonAll(token, key, algorithms).map(_._2)
decodeJson(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJson(token: String, key: SecretKey): Try[J] = decodeJson(token, key, JwtAlgorithm.allHmac)
def decodeJson(token: String, key: SecretKey, options: JwtOptions): Try[J] =
decodeJson(token, key, JwtAlgorithm.allHmac, options)

def decodeJson(token: String, key: SecretKey): Try[J] =
decodeJson(token, key, JwtOptions.DEFAULT)

def decodeJson(token: String, key: PublicKey, algorithms: Seq[JwtAsymetricAlgorithm], options: JwtOptions): Try[J] =
decodeJsonAll(token, key, algorithms, options).map(_._2)

def decodeJson(token: String, key: PublicKey, algorithms: Seq[JwtAsymetricAlgorithm]): Try[J] =
decodeJsonAll(token, key, algorithms).map(_._2)
decodeJson(token, key, algorithms, JwtOptions.DEFAULT)

def decodeJson(token: String, key: PublicKey, options: JwtOptions): Try[J] =
decodeJson(token, key, JwtAlgorithm.allAsymetric, options)

def decodeJson(token: String, key: PublicKey): Try[J] = decodeJson(token, key, JwtAlgorithm.allAsymetric)
def decodeJson(token: String, key: PublicKey): Try[J] =
decodeJson(token, key, JwtOptions.DEFAULT)
}
29 changes: 29 additions & 0 deletions json/common/src/test/scala/JwtJsonCommonSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,34 @@ abstract class JwtJsonCommonSpec[J] extends UnitSpec with JsonCommonFixture[J] {

mock.tearDown
}

it("should success to decodeJsonAll and decodeJson when now is after expiration date with options") {
val mock = mockAfterExpiration
val options = JwtOptions(expiration = false)

dataJson foreach { d =>
jwtJsonCommon.decodeJsonAll(d.token, secretKey, JwtAlgorithm.allHmac, options).get
assert(jwtJsonCommon.decodeJsonAll(d.token, secretKey, JwtAlgorithm.allHmac, options).isSuccess)

jwtJsonCommon.decodeJson(d.token, secretKey, JwtAlgorithm.allHmac, options).get
assert(jwtJsonCommon.decodeJson(d.token, secretKey, JwtAlgorithm.allHmac, options).isSuccess)

jwtJsonCommon.decodeAll(d.token, secretKey, JwtAlgorithm.allHmac, options).get
assert(jwtJsonCommon.decodeAll(d.token, secretKey, JwtAlgorithm.allHmac, options).isSuccess)
}

dataRSAJson foreach { d =>
jwtJsonCommon.decodeJsonAll(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options).get
assert(jwtJsonCommon.decodeJsonAll(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options).isSuccess)

jwtJsonCommon.decodeJson(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options).get
assert(jwtJsonCommon.decodeJson(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options).isSuccess)

jwtJsonCommon.decodeAll(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options).get
assert(jwtJsonCommon.decodeAll(d.token, publicKeyRSA, JwtAlgorithm.allRSA, options).isSuccess)
}

mock.tearDown
}
}
}
8 changes: 5 additions & 3 deletions play/src/main/scala/JwtSession.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,15 @@ object JwtSession {
private def key: Option[String] =
Play.maybeApplication.flatMap(_.configuration.getString("play.crypto.secret"))

def deserialize(token: String): JwtSession = (key match {
case Some(k) => JwtJson.decodeJsonAll(token, k, Seq(ALGORITHM))
case _ => JwtJson.decodeJsonAll(token)
def deserialize(token: String, options: JwtOptions): JwtSession = (key match {
case Some(k) => JwtJson.decodeJsonAll(token, k, Seq(ALGORITHM), options)
case _ => JwtJson.decodeJsonAll(token, options)
}).map { tuple =>
JwtSession(tuple._1, tuple._2, tuple._3)
}.getOrElse(JwtSession())

def deserialize(token: String): JwtSession = deserialize(token, JwtOptions.DEFAULT)

private def asJsObject[A](value: A)(implicit writer: Writes[A]): JsObject = writer.writes(value) match {
case value: JsObject => value
case _ => Json.obj()
Expand Down

0 comments on commit e399d82

Please sign in to comment.