diff --git a/Sources/App/Controllers/View Controllers/LoginViewController.swift b/Sources/App/Controllers/View Controllers/LoginViewController.swift index 859f7f2..cd33353 100644 --- a/Sources/App/Controllers/View Controllers/LoginViewController.swift +++ b/Sources/App/Controllers/View Controllers/LoginViewController.swift @@ -28,11 +28,7 @@ class LoginViewController: RouteCollection { } func login(req: Request, content: LoginRequest) throws -> Future { - guard let correctCSRFKey = try req.session()["csrf"] else { throw Abort(.badRequest) } - let submittedKey = content.csrf - - guard correctCSRFKey == submittedKey else { throw Abort(.unauthorized) } - try req.session()["csrf"] = nil + try req.verifyCSRF() let userQuery = User.query(on: req) .filter(\.email == content.email) diff --git a/Sources/App/Controllers/View Controllers/RegisterViewController.swift b/Sources/App/Controllers/View Controllers/RegisterViewController.swift index 6b1f942..ad46b34 100644 --- a/Sources/App/Controllers/View Controllers/RegisterViewController.swift +++ b/Sources/App/Controllers/View Controllers/RegisterViewController.swift @@ -24,6 +24,8 @@ class RegisterViewController: RouteCollection { } func register(req: Request, content: RegisterRequest) throws -> Future { + try req.verifyCSRF() + guard content.password == content.confirmPassword else { throw RedirectError(to: "/register", error: "Passwords don't match") } let existingUserQuery = User diff --git a/Sources/App/Setup/configure.swift b/Sources/App/Setup/configure.swift index c5ae91f..ff8336d 100644 --- a/Sources/App/Setup/configure.swift +++ b/Sources/App/Setup/configure.swift @@ -5,6 +5,7 @@ import Redis import VaporSecurityHeaders import URLEncodedForm import Authentication +import Flash public func configure(_ config: inout Config, _ env: inout Environment, _ services: inout Services) throws { /// Register providers first @@ -63,6 +64,12 @@ public func configure(_ config: inout Config, _ env: inout Environment, _ servic return KeyedCacheSessions(keyedCache: keyedCache) } + services.register(CSRF.self) { _ -> CSRFVerifier in + return CSRFVerifier() + } + + config.prefer(CSRFVerifier.self, for: CSRF.self) + /// Setup Security Headers let cspConfig = ContentSecurityPolicyConfiguration(value: CSPConfig.setupCSP().generateString()) let xssProtectionConfig = XSSProtectionConfiguration(option: .block) @@ -116,4 +123,13 @@ public func configure(_ config: inout Config, _ env: inout Environment, _ servic /// Register KeyStorage guard let apiKey = Environment.get(Constants.restMiddlewareEnvKey) else { throw Abort(.internalServerError) } services.register(KeyStorage(restMiddlewareApiKey: apiKey)) + + //Leaf Tag Config + var defaultTags = LeafTagConfig.default() + defaultTags.use(FlashTag(), as: "flash") + + services.register(defaultTags) + + /// Flash Provider + try services.register(FlashProvider()) } diff --git a/Sources/App/Utils/CSRF.swift b/Sources/App/Utils/CSRF.swift deleted file mode 100644 index 2dbf5e8..0000000 --- a/Sources/App/Utils/CSRF.swift +++ /dev/null @@ -1,60 +0,0 @@ -// -// CSRF.swift -// App -// -// Created by Jimmy McDermott on 6/12/18. -// - -import Foundation -// -// CSRF.swift -// App -// -// Created by Jimmy McDermott on 11/7/17. -// - -import Foundation -import Vapor -import Crypto - -extension Request { - static let csrfKey = "_csrf" - - func verifyCSRF(submittedToken: String? = nil, key: String = Request.csrfKey) throws { - let validSession = try session() - guard let requiredToken: String = validSession[key] else { throw Abort(.forbidden) } - - if let token = submittedToken { - guard token == requiredToken else { throw Abort(.forbidden) } - } else { - let _ = try content.decode([String: String].self).map(to: Void.self) { form in - guard let submittedToken: String = form[key] else { throw Abort(.forbidden) } - guard requiredToken == submittedToken else { throw Abort(.forbidden) } - } - } - } - - private func generateRandom() throws -> String { - return try CryptoRandom().generateData(count: 4).hexEncodedString() - } - - func setCSRF(key: String = Request.csrfKey) throws -> String { - let string = "\(try generateRandom())-\(try generateRandom())-\(try generateRandom())-\(try generateRandom())" - try session()[key] = string - - return string - } -} - -protocol CSRFViewContext { - var csrf: String { get set } -} - -struct CSRFContext: CSRFViewContext, ViewContext { - var common: CommonViewContext? - var csrf: String - - init(csrf: String) { - self.csrf = csrf - } -} diff --git a/Sources/App/Utils/CSRF/CSRF.swift b/Sources/App/Utils/CSRF/CSRF.swift new file mode 100644 index 0000000..bcbf148 --- /dev/null +++ b/Sources/App/Utils/CSRF/CSRF.swift @@ -0,0 +1,15 @@ +import Foundation +import Crypto +import Vapor + +protocol CSRF: Service { + func verifyCSRF(submittedToken: String?, key: String, request: Request) throws + func setCSRF(key: String, request: Request) throws -> String + func generateRandom() throws -> String +} + +extension CSRF { + func generateRandom() throws -> String { + return try CryptoRandom().generateData(count: 4).hexEncodedString() + } +} diff --git a/Sources/App/Utils/CSRF/CSRFContext.swift b/Sources/App/Utils/CSRF/CSRFContext.swift new file mode 100644 index 0000000..34511cf --- /dev/null +++ b/Sources/App/Utils/CSRF/CSRFContext.swift @@ -0,0 +1,10 @@ +import Foundation + +struct CSRFContext: CSRFViewContext, ViewContext { + var common: CommonViewContext? + var csrf: String + + init(csrf: String) { + self.csrf = csrf + } +} diff --git a/Sources/App/Utils/CSRF/CSRFVerifier.swift b/Sources/App/Utils/CSRF/CSRFVerifier.swift new file mode 100644 index 0000000..9b49ef5 --- /dev/null +++ b/Sources/App/Utils/CSRF/CSRFVerifier.swift @@ -0,0 +1,24 @@ +import Foundation +import Vapor + +struct CSRFVerifier: CSRF { + func setCSRF(key: String, request: Request) throws -> String { + let string = "\(try generateRandom())-\(try generateRandom())-\(try generateRandom())-\(try generateRandom())" + try request.session()[key] = string + + return string + } + + func verifyCSRF(submittedToken: String?, key: String, request: Request) throws { + guard let requiredToken: String = try request.session()[key] else { throw Abort(.forbidden) } + + if let token = submittedToken { + guard token == requiredToken else { throw Abort(.forbidden) } + } else { + let _ = try request.content.decode([String: String].self).map(to: Void.self) { form in + guard let submittedToken: String = form[key] else { throw Abort(.forbidden) } + guard requiredToken == submittedToken else { throw Abort(.forbidden) } + } + } + } +} diff --git a/Sources/App/Utils/CSRF/CSRFViewContext.swift b/Sources/App/Utils/CSRF/CSRFViewContext.swift new file mode 100644 index 0000000..72b27c3 --- /dev/null +++ b/Sources/App/Utils/CSRF/CSRFViewContext.swift @@ -0,0 +1,5 @@ +import Foundation + +protocol CSRFViewContext { + var csrf: String { get set } +} diff --git a/Sources/App/Utils/CSRF/EmptyCSRFVerifier.swift b/Sources/App/Utils/CSRF/EmptyCSRFVerifier.swift new file mode 100644 index 0000000..f79836c --- /dev/null +++ b/Sources/App/Utils/CSRF/EmptyCSRFVerifier.swift @@ -0,0 +1,11 @@ +import Foundation +import Vapor + +struct EmptyCSRFVerifier: CSRF { + func verifyCSRF(submittedToken: String?, key: String, request: Request) throws { + + } + func setCSRF(key: String, request: Request) throws -> String { + return "" + } +} diff --git a/Sources/App/Utils/CSRF/Request+CSRF.swift b/Sources/App/Utils/CSRF/Request+CSRF.swift new file mode 100644 index 0000000..76361e6 --- /dev/null +++ b/Sources/App/Utils/CSRF/Request+CSRF.swift @@ -0,0 +1,16 @@ +import Foundation +import Vapor + +extension Request { + static let csrfKey = "csrf" + + func verifyCSRF(submittedToken: String? = nil, key: String = Request.csrfKey) throws { + let csrf = try make(CSRF.self) + return try csrf.verifyCSRF(submittedToken: submittedToken, key: key, request: self) + } + + func setCSRF(key: String = Request.csrfKey) throws -> String { + let csrf = try make(CSRF.self) + return try csrf.setCSRF(key: key, request: self) + } +} diff --git a/Tests/AppTests/Tests/Authed Tests/LoginTests.swift b/Tests/AppTests/Tests/Authed Tests/LoginTests.swift index e8e5818..c0d8898 100644 --- a/Tests/AppTests/Tests/Authed Tests/LoginTests.swift +++ b/Tests/AppTests/Tests/Authed Tests/LoginTests.swift @@ -1,6 +1,7 @@ import XCTest import Foundation import FluentMySQL +import Crypto @testable import Vapor @testable import App @@ -22,10 +23,23 @@ class LoginTests: XCTestCase { func testLinuxTestSuiteIncludesAllTests() { #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) let thisClass = type(of: self) -// let linuxCount = thisClass.__allTests.count - let linuxCount = 0 + let linuxCount = thisClass.__allTests.count let darwinCount = Int(thisClass.defaultTestSuite.testCaseCount) XCTAssertEqual(linuxCount, darwinCount, "\(darwinCount - linuxCount) tests are missing from allTests") #endif } + + /// Tests that a user with invalid credentials cannot login + func testLoginInvalidCredentials() throws { + let _ = try User(name: "name", email: "email@email.com", password: try BCrypt.hash("password")).save(on: conn).wait() + let loginRequest = LoginRequest(email: "email@email.com", password: "wrong password", csrf: "n/a") + + let loginResponse = try app.sendRequest(to: "/login", method: .POST, data: loginRequest, contentType: .json) + XCTAssertEqual(loginResponse.http.headers.firstValue(name: .location), "/login") + } + + /// Tests that a user with valid credentials can login + func testLoginSuccessful() throws { + + } } diff --git a/Tests/AppTests/Tests/Authed Tests/RegisterTests.swift b/Tests/AppTests/Tests/Authed Tests/RegisterTests.swift index 6c12c3c..ad8ccbf 100644 --- a/Tests/AppTests/Tests/Authed Tests/RegisterTests.swift +++ b/Tests/AppTests/Tests/Authed Tests/RegisterTests.swift @@ -22,10 +22,29 @@ class RegisterTests: XCTestCase { func testLinuxTestSuiteIncludesAllTests() { #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) let thisClass = type(of: self) - // let linuxCount = thisClass.__allTests.count - let linuxCount = 0 + let linuxCount = thisClass.__allTests.count let darwinCount = Int(thisClass.defaultTestSuite.testCaseCount) XCTAssertEqual(linuxCount, darwinCount, "\(darwinCount - linuxCount) tests are missing from allTests") #endif } + + /// Tests that an email cannot be registered twice + func testRegisterEmailAlreadyExists() throws { + + } + + /// Tests that an invalid email cannot register + func testInvalidEmailFails() throws { + + } + + /// Tests that passwords must match + func testRegisterPasswordsDontMatch() throws { + + } + + /// Tests that users can register successfully when meeting validation requirements + func testSuccessfulRegister() throws { + + } } diff --git a/Tests/AppTests/Utilities/Utilities.swift b/Tests/AppTests/Utilities/Utilities.swift index 7ea875e..96be565 100644 --- a/Tests/AppTests/Utilities/Utilities.swift +++ b/Tests/AppTests/Utilities/Utilities.swift @@ -27,7 +27,12 @@ extension Application { return MockLogger() } + services.register(CSRF.self) { _ -> EmptyCSRFVerifier in + return EmptyCSRFVerifier() + } + config.prefer(MockLogger.self, for: Logger.self) + config.prefer(EmptyCSRFVerifier.self, for: CSRF.self) let app = try Application(config: config, environment: env, services: services) diff --git a/Tests/AppTests/XCTestManifests.swift b/Tests/AppTests/XCTestManifests.swift new file mode 100644 index 0000000..1ed0758 --- /dev/null +++ b/Tests/AppTests/XCTestManifests.swift @@ -0,0 +1,28 @@ +import XCTest + +extension LoginTests { + static let __allTests = [ + ("testLinuxTestSuiteIncludesAllTests", testLinuxTestSuiteIncludesAllTests), + ("testLoginInvalidCredentials", testLoginInvalidCredentials), + ("testLoginSuccessful", testLoginSuccessful), + ] +} + +extension RegisterTests { + static let __allTests = [ + ("testLinuxTestSuiteIncludesAllTests", testLinuxTestSuiteIncludesAllTests), + ("testInvalidEmailFails", testInvalidEmailFails), + ("testRegisterEmailAlreadyExists", testRegisterEmailAlreadyExists), + ("testRegisterPasswordsDontMatch", testRegisterPasswordsDontMatch), + ("testSuccessfulRegister", testSuccessfulRegister), + ] +} + +#if !os(macOS) +public func __allTests() -> [XCTestCaseEntry] { + return [ + testCase(LoginTests.__allTests), + testCase(RegisterTests.__allTests), + ] +} +#endif diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index e69de29..177e5c7 100755 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -0,0 +1,8 @@ +import XCTest + +import AppTests + +var tests = [XCTestCaseEntry]() +tests += AppTests.__allTests() + +XCTMain(tests)