diff --git a/pom.xml b/pom.xml
index 4152eeb..c63effc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -55,6 +55,10 @@
org.springframework.boot
spring-boot-starter-web
+
+ org.springframework.mobile
+ spring-mobile-device
+
io.jsonwebtoken
jjwt
diff --git a/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java b/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java
index 4031d1c..144222a 100644
--- a/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java
+++ b/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java
@@ -6,9 +6,12 @@
import javax.servlet.http.HttpServletRequest;
+import org.apache.log4j.Logger;
+
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.ResponseEntity;
+import org.springframework.mobile.device.Device;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
@@ -25,6 +28,8 @@
@RequestMapping("${cerberus.route.authentication}")
public class AuthenticationController {
+ private final Logger logger = Logger.getLogger(this.getClass());
+
@Value("${cerberus.token.header}")
private String tokenHeader;
@@ -38,7 +43,7 @@ public class AuthenticationController {
private UserDetailsService userDetailsService;
@RequestMapping(method = RequestMethod.POST)
- public ResponseEntity> authenticationRequest(@RequestBody AuthenticationRequest authenticationRequest) throws AuthenticationException {
+ public ResponseEntity> authenticationRequest(@RequestBody AuthenticationRequest authenticationRequest, Device device) throws AuthenticationException {
// Perform the authentication
Authentication authentication = this.authenticationManager.authenticate(
@@ -51,7 +56,7 @@ public ResponseEntity> authenticationRequest(@RequestBody AuthenticationReques
// Reload password post-authentication so we can generate token
UserDetails userDetails = this.userDetailsService.loadUserByUsername(authenticationRequest.getUsername());
- String token = this.tokenUtils.generateToken(userDetails);
+ String token = this.tokenUtils.generateToken(userDetails, device);
// Return the token
return ResponseEntity.ok(new AuthenticationResponse(token));
@@ -60,8 +65,12 @@ public ResponseEntity> authenticationRequest(@RequestBody AuthenticationReques
@RequestMapping(value = "${cerberus.route.authentication.refresh}", method = RequestMethod.GET)
public ResponseEntity> authenticationRequest(HttpServletRequest request) {
String token = request.getHeader(this.tokenHeader);
- String refreshedToken = this.tokenUtils.refreshToken(token);
- return ResponseEntity.ok(new AuthenticationResponse(refreshedToken));
+ if (this.tokenUtils.isTokenExpired(token)) {
+ return ResponseEntity.badRequest().body("Token Expired");
+ } else {
+ String refreshedToken = this.tokenUtils.refreshToken(token);
+ return ResponseEntity.ok(new AuthenticationResponse(refreshedToken));
+ }
}
}
diff --git a/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java b/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java
index 8b8fef7..fc1e902 100644
--- a/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java
+++ b/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java
@@ -2,7 +2,11 @@
import io.jsonwebtoken.*;
+import java.util.HashMap;
+import java.util.Map;
+
import org.springframework.beans.factory.annotation.Value;
+import org.springframework.mobile.device.Device;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
@@ -11,6 +15,11 @@
@Component
public class TokenUtils {
+ private final String AUDIENCE_UNKNOWN = "unknown";
+ private final String AUDIENCE_WEB = "web";
+ private final String AUDIENCE_MOBILE = "mobile";
+ private final String AUDIENCE_TABLET = "tablet";
+
@Value("${cerberus.token.secret}")
private String secret;
@@ -39,6 +48,17 @@ public Date getExpirationDateFromToken(String token) {
return expiration;
}
+ public String getAudienceFromToken(String token) {
+ String audience;
+ try {
+ final Claims claims = this.getClaimsFromToken(token);
+ audience = (String) claims.get("audience");
+ } catch (Exception e) {
+ audience = null;
+ }
+ return audience;
+ }
+
private Claims getClaimsFromToken(String token) {
Claims claims;
try {
@@ -56,17 +76,38 @@ private Date generateExpirationDate() {
return new Date(System.currentTimeMillis() + this.expiration * 1000);
}
- public String generateToken(UserDetails userDetails) {
- return Jwts.builder()
- .setSubject(userDetails.getUsername())
- .setExpiration(this.generateExpirationDate())
- .signWith(SignatureAlgorithm.HS512, this.secret)
- .compact();
+ public Boolean isTokenExpired(String token) {
+ final Date expiration = this.getExpirationDateFromToken(token);
+ return (expiration.before(new Date(System.currentTimeMillis())) && !(this.ignoreTokenExpiration(token)));
+ }
+
+ private String generateAudience(Device device) {
+ String audience = this.AUDIENCE_UNKNOWN;
+ if (device.isNormal()) {
+ audience = this.AUDIENCE_WEB;
+ } else if (device.isTablet()) {
+ audience = AUDIENCE_TABLET;
+ } else if (device.isMobile()) {
+ audience = AUDIENCE_MOBILE;
+ }
+ return audience;
+ }
+
+ private Boolean ignoreTokenExpiration(String token) {
+ String audience = this.getAudienceFromToken(token);
+ return (this.AUDIENCE_TABLET.equals(audience) || this.AUDIENCE_MOBILE.equals(audience));
+ }
+
+ public String generateToken(UserDetails userDetails, Device device) {
+ Map claims = new HashMap();
+ claims.put("sub", userDetails.getUsername());
+ claims.put("audience", this.generateAudience(device));
+ return this.generateToken(claims);
}
- public String generateToken(String subject) {
+ private String generateToken(Map claims) {
return Jwts.builder()
- .setSubject(subject)
+ .setClaims(claims)
.setExpiration(this.generateExpirationDate())
.signWith(SignatureAlgorithm.HS512, this.secret)
.compact();
@@ -76,17 +117,17 @@ public String refreshToken(String token) {
String refreshedToken;
try {
final Claims claims = this.getClaimsFromToken(token);
- refreshedToken = this.generateToken(claims.getSubject());
+ refreshedToken = this.generateToken(claims);
} catch (Exception e) {
refreshedToken = null;
}
return refreshedToken;
}
- public boolean validateToken(String token, UserDetails userDetails) {
+ public Boolean validateToken(String token, UserDetails userDetails) {
final String username = this.getUsernameFromToken(token);
final Date expiration = this.getExpirationDateFromToken(token);
- return (username.equals(userDetails.getUsername()) && expiration.after(new Date(System.currentTimeMillis())));
+ return (username.equals(userDetails.getUsername()) && !(this.isTokenExpired(token)));
}
}