diff --git a/pom.xml b/pom.xml index 4152eeb..c63effc 100644 --- a/pom.xml +++ b/pom.xml @@ -55,6 +55,10 @@ org.springframework.boot spring-boot-starter-web + + org.springframework.mobile + spring-mobile-device + io.jsonwebtoken jjwt diff --git a/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java b/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java index 4031d1c..144222a 100644 --- a/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java +++ b/src/main/java/com/brahalla/Cerberus/controller/rest/AuthenticationController.java @@ -6,9 +6,12 @@ import javax.servlet.http.HttpServletRequest; +import org.apache.log4j.Logger; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.http.ResponseEntity; +import org.springframework.mobile.device.Device; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; @@ -25,6 +28,8 @@ @RequestMapping("${cerberus.route.authentication}") public class AuthenticationController { + private final Logger logger = Logger.getLogger(this.getClass()); + @Value("${cerberus.token.header}") private String tokenHeader; @@ -38,7 +43,7 @@ public class AuthenticationController { private UserDetailsService userDetailsService; @RequestMapping(method = RequestMethod.POST) - public ResponseEntity authenticationRequest(@RequestBody AuthenticationRequest authenticationRequest) throws AuthenticationException { + public ResponseEntity authenticationRequest(@RequestBody AuthenticationRequest authenticationRequest, Device device) throws AuthenticationException { // Perform the authentication Authentication authentication = this.authenticationManager.authenticate( @@ -51,7 +56,7 @@ public ResponseEntity authenticationRequest(@RequestBody AuthenticationReques // Reload password post-authentication so we can generate token UserDetails userDetails = this.userDetailsService.loadUserByUsername(authenticationRequest.getUsername()); - String token = this.tokenUtils.generateToken(userDetails); + String token = this.tokenUtils.generateToken(userDetails, device); // Return the token return ResponseEntity.ok(new AuthenticationResponse(token)); @@ -60,8 +65,12 @@ public ResponseEntity authenticationRequest(@RequestBody AuthenticationReques @RequestMapping(value = "${cerberus.route.authentication.refresh}", method = RequestMethod.GET) public ResponseEntity authenticationRequest(HttpServletRequest request) { String token = request.getHeader(this.tokenHeader); - String refreshedToken = this.tokenUtils.refreshToken(token); - return ResponseEntity.ok(new AuthenticationResponse(refreshedToken)); + if (this.tokenUtils.isTokenExpired(token)) { + return ResponseEntity.badRequest().body("Token Expired"); + } else { + String refreshedToken = this.tokenUtils.refreshToken(token); + return ResponseEntity.ok(new AuthenticationResponse(refreshedToken)); + } } } diff --git a/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java b/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java index 8b8fef7..fc1e902 100644 --- a/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java +++ b/src/main/java/com/brahalla/Cerberus/security/TokenUtils.java @@ -2,7 +2,11 @@ import io.jsonwebtoken.*; +import java.util.HashMap; +import java.util.Map; + import org.springframework.beans.factory.annotation.Value; +import org.springframework.mobile.device.Device; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.stereotype.Component; @@ -11,6 +15,11 @@ @Component public class TokenUtils { + private final String AUDIENCE_UNKNOWN = "unknown"; + private final String AUDIENCE_WEB = "web"; + private final String AUDIENCE_MOBILE = "mobile"; + private final String AUDIENCE_TABLET = "tablet"; + @Value("${cerberus.token.secret}") private String secret; @@ -39,6 +48,17 @@ public Date getExpirationDateFromToken(String token) { return expiration; } + public String getAudienceFromToken(String token) { + String audience; + try { + final Claims claims = this.getClaimsFromToken(token); + audience = (String) claims.get("audience"); + } catch (Exception e) { + audience = null; + } + return audience; + } + private Claims getClaimsFromToken(String token) { Claims claims; try { @@ -56,17 +76,38 @@ private Date generateExpirationDate() { return new Date(System.currentTimeMillis() + this.expiration * 1000); } - public String generateToken(UserDetails userDetails) { - return Jwts.builder() - .setSubject(userDetails.getUsername()) - .setExpiration(this.generateExpirationDate()) - .signWith(SignatureAlgorithm.HS512, this.secret) - .compact(); + public Boolean isTokenExpired(String token) { + final Date expiration = this.getExpirationDateFromToken(token); + return (expiration.before(new Date(System.currentTimeMillis())) && !(this.ignoreTokenExpiration(token))); + } + + private String generateAudience(Device device) { + String audience = this.AUDIENCE_UNKNOWN; + if (device.isNormal()) { + audience = this.AUDIENCE_WEB; + } else if (device.isTablet()) { + audience = AUDIENCE_TABLET; + } else if (device.isMobile()) { + audience = AUDIENCE_MOBILE; + } + return audience; + } + + private Boolean ignoreTokenExpiration(String token) { + String audience = this.getAudienceFromToken(token); + return (this.AUDIENCE_TABLET.equals(audience) || this.AUDIENCE_MOBILE.equals(audience)); + } + + public String generateToken(UserDetails userDetails, Device device) { + Map claims = new HashMap(); + claims.put("sub", userDetails.getUsername()); + claims.put("audience", this.generateAudience(device)); + return this.generateToken(claims); } - public String generateToken(String subject) { + private String generateToken(Map claims) { return Jwts.builder() - .setSubject(subject) + .setClaims(claims) .setExpiration(this.generateExpirationDate()) .signWith(SignatureAlgorithm.HS512, this.secret) .compact(); @@ -76,17 +117,17 @@ public String refreshToken(String token) { String refreshedToken; try { final Claims claims = this.getClaimsFromToken(token); - refreshedToken = this.generateToken(claims.getSubject()); + refreshedToken = this.generateToken(claims); } catch (Exception e) { refreshedToken = null; } return refreshedToken; } - public boolean validateToken(String token, UserDetails userDetails) { + public Boolean validateToken(String token, UserDetails userDetails) { final String username = this.getUsernameFromToken(token); final Date expiration = this.getExpirationDateFromToken(token); - return (username.equals(userDetails.getUsername()) && expiration.after(new Date(System.currentTimeMillis()))); + return (username.equals(userDetails.getUsername()) && !(this.isTokenExpired(token))); } }