Skip to content

Commit

Permalink
add csrf and first test
Browse files Browse the repository at this point in the history
  • Loading branch information
jdmcd committed Jun 15, 2018
1 parent adf8019 commit d01330d
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ class LoginViewController: RouteCollection {
}

func login(req: Request, content: LoginRequest) throws -> Future<Response> {
guard let correctCSRFKey = try req.session()["csrf"] else { throw Abort(.badRequest) }
let submittedKey = content.csrf

guard correctCSRFKey == submittedKey else { throw Abort(.unauthorized) }
try req.session()["csrf"] = nil
try req.verifyCSRF()

let userQuery = User.query(on: req)
.filter(\.email == content.email)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class RegisterViewController: RouteCollection {
}

func register(req: Request, content: RegisterRequest) throws -> Future<Response> {
try req.verifyCSRF()

guard content.password == content.confirmPassword else { throw RedirectError(to: "/register", error: "Passwords don't match") }

let existingUserQuery = User
Expand Down
16 changes: 16 additions & 0 deletions Sources/App/Setup/configure.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Redis
import VaporSecurityHeaders
import URLEncodedForm
import Authentication
import Flash

public func configure(_ config: inout Config, _ env: inout Environment, _ services: inout Services) throws {
/// Register providers first
Expand Down Expand Up @@ -63,6 +64,12 @@ public func configure(_ config: inout Config, _ env: inout Environment, _ servic
return KeyedCacheSessions(keyedCache: keyedCache)
}

services.register(CSRF.self) { _ -> CSRFVerifier in
return CSRFVerifier()
}

config.prefer(CSRFVerifier.self, for: CSRF.self)

/// Setup Security Headers
let cspConfig = ContentSecurityPolicyConfiguration(value: CSPConfig.setupCSP().generateString())
let xssProtectionConfig = XSSProtectionConfiguration(option: .block)
Expand Down Expand Up @@ -116,4 +123,13 @@ public func configure(_ config: inout Config, _ env: inout Environment, _ servic
/// Register KeyStorage
guard let apiKey = Environment.get(Constants.restMiddlewareEnvKey) else { throw Abort(.internalServerError) }
services.register(KeyStorage(restMiddlewareApiKey: apiKey))

//Leaf Tag Config
var defaultTags = LeafTagConfig.default()
defaultTags.use(FlashTag(), as: "flash")

services.register(defaultTags)

/// Flash Provider
try services.register(FlashProvider())
}
60 changes: 0 additions & 60 deletions Sources/App/Utils/CSRF.swift

This file was deleted.

15 changes: 15 additions & 0 deletions Sources/App/Utils/CSRF/CSRF.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import Foundation
import Crypto
import Vapor

protocol CSRF: Service {
func verifyCSRF(submittedToken: String?, key: String, request: Request) throws
func setCSRF(key: String, request: Request) throws -> String
func generateRandom() throws -> String
}

extension CSRF {
func generateRandom() throws -> String {
return try CryptoRandom().generateData(count: 4).hexEncodedString()
}
}
10 changes: 10 additions & 0 deletions Sources/App/Utils/CSRF/CSRFContext.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Foundation

struct CSRFContext: CSRFViewContext, ViewContext {
var common: CommonViewContext?
var csrf: String

init(csrf: String) {
self.csrf = csrf
}
}
24 changes: 24 additions & 0 deletions Sources/App/Utils/CSRF/CSRFVerifier.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import Foundation
import Vapor

struct CSRFVerifier: CSRF {
func setCSRF(key: String, request: Request) throws -> String {
let string = "\(try generateRandom())-\(try generateRandom())-\(try generateRandom())-\(try generateRandom())"
try request.session()[key] = string

return string
}

func verifyCSRF(submittedToken: String?, key: String, request: Request) throws {
guard let requiredToken: String = try request.session()[key] else { throw Abort(.forbidden) }

if let token = submittedToken {
guard token == requiredToken else { throw Abort(.forbidden) }
} else {
let _ = try request.content.decode([String: String].self).map(to: Void.self) { form in
guard let submittedToken: String = form[key] else { throw Abort(.forbidden) }
guard requiredToken == submittedToken else { throw Abort(.forbidden) }
}
}
}
}
5 changes: 5 additions & 0 deletions Sources/App/Utils/CSRF/CSRFViewContext.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import Foundation

protocol CSRFViewContext {
var csrf: String { get set }
}
11 changes: 11 additions & 0 deletions Sources/App/Utils/CSRF/EmptyCSRFVerifier.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import Foundation
import Vapor

struct EmptyCSRFVerifier: CSRF {
func verifyCSRF(submittedToken: String?, key: String, request: Request) throws {

}
func setCSRF(key: String, request: Request) throws -> String {
return ""
}
}
16 changes: 16 additions & 0 deletions Sources/App/Utils/CSRF/Request+CSRF.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import Foundation
import Vapor

extension Request {
static let csrfKey = "csrf"

func verifyCSRF(submittedToken: String? = nil, key: String = Request.csrfKey) throws {
let csrf = try make(CSRF.self)
return try csrf.verifyCSRF(submittedToken: submittedToken, key: key, request: self)
}

func setCSRF(key: String = Request.csrfKey) throws -> String {
let csrf = try make(CSRF.self)
return try csrf.setCSRF(key: key, request: self)
}
}
18 changes: 16 additions & 2 deletions Tests/AppTests/Tests/Authed Tests/LoginTests.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import XCTest
import Foundation
import FluentMySQL
import Crypto
@testable import Vapor
@testable import App

Expand All @@ -22,10 +23,23 @@ class LoginTests: XCTestCase {
func testLinuxTestSuiteIncludesAllTests() {
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
let thisClass = type(of: self)
// let linuxCount = thisClass.__allTests.count
let linuxCount = 0
let linuxCount = thisClass.__allTests.count
let darwinCount = Int(thisClass.defaultTestSuite.testCaseCount)
XCTAssertEqual(linuxCount, darwinCount, "\(darwinCount - linuxCount) tests are missing from allTests")
#endif
}

/// Tests that a user with invalid credentials cannot login
func testLoginInvalidCredentials() throws {
let _ = try User(name: "name", email: "[email protected]", password: try BCrypt.hash("password")).save(on: conn).wait()
let loginRequest = LoginRequest(email: "[email protected]", password: "wrong password", csrf: "n/a")

let loginResponse = try app.sendRequest(to: "/login", method: .POST, data: loginRequest, contentType: .json)
XCTAssertEqual(loginResponse.http.headers.firstValue(name: .location), "/login")
}

/// Tests that a user with valid credentials can login
func testLoginSuccessful() throws {

}
}
23 changes: 21 additions & 2 deletions Tests/AppTests/Tests/Authed Tests/RegisterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,29 @@ class RegisterTests: XCTestCase {
func testLinuxTestSuiteIncludesAllTests() {
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
let thisClass = type(of: self)
// let linuxCount = thisClass.__allTests.count
let linuxCount = 0
let linuxCount = thisClass.__allTests.count
let darwinCount = Int(thisClass.defaultTestSuite.testCaseCount)
XCTAssertEqual(linuxCount, darwinCount, "\(darwinCount - linuxCount) tests are missing from allTests")
#endif
}

/// Tests that an email cannot be registered twice
func testRegisterEmailAlreadyExists() throws {

}

/// Tests that an invalid email cannot register
func testInvalidEmailFails() throws {

}

/// Tests that passwords must match
func testRegisterPasswordsDontMatch() throws {

}

/// Tests that users can register successfully when meeting validation requirements
func testSuccessfulRegister() throws {

}
}
5 changes: 5 additions & 0 deletions Tests/AppTests/Utilities/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ extension Application {
return MockLogger()
}

services.register(CSRF.self) { _ -> EmptyCSRFVerifier in
return EmptyCSRFVerifier()
}

config.prefer(MockLogger.self, for: Logger.self)
config.prefer(EmptyCSRFVerifier.self, for: CSRF.self)

let app = try Application(config: config, environment: env, services: services)

Expand Down
28 changes: 28 additions & 0 deletions Tests/AppTests/XCTestManifests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import XCTest

extension LoginTests {
static let __allTests = [
("testLinuxTestSuiteIncludesAllTests", testLinuxTestSuiteIncludesAllTests),
("testLoginInvalidCredentials", testLoginInvalidCredentials),
("testLoginSuccessful", testLoginSuccessful),
]
}

extension RegisterTests {
static let __allTests = [
("testLinuxTestSuiteIncludesAllTests", testLinuxTestSuiteIncludesAllTests),
("testInvalidEmailFails", testInvalidEmailFails),
("testRegisterEmailAlreadyExists", testRegisterEmailAlreadyExists),
("testRegisterPasswordsDontMatch", testRegisterPasswordsDontMatch),
("testSuccessfulRegister", testSuccessfulRegister),
]
}

#if !os(macOS)
public func __allTests() -> [XCTestCaseEntry] {
return [
testCase(LoginTests.__allTests),
testCase(RegisterTests.__allTests),
]
}
#endif
8 changes: 8 additions & 0 deletions Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import XCTest

import AppTests

var tests = [XCTestCaseEntry]()
tests += AppTests.__allTests()

XCTMain(tests)

0 comments on commit d01330d

Please sign in to comment.