Skip to content

Commit

Permalink
feat: BI-5696 csrf multiple secrets support added (#579)
Browse files Browse the repository at this point in the history
* csrf multiple secrets support added

* comment removed
  • Loading branch information
juliarbkv authored Aug 21, 2024
1 parent e85c4ab commit e3ce7f1
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 14 deletions.
9 changes: 5 additions & 4 deletions lib/dl_api_commons/dl_api_commons/aio/middlewares/csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class CSRFMiddleware:

csrf_header_name: str = attr.ib()
csrf_time_limit: int = attr.ib()
csrf_secret: str = attr.ib()
csrf_secrets: tuple[str, ...] = attr.ib()
csrf_methods: tuple[str, ...] = attr.ib(default=("POST", "PUT", "DELETE"))

def validate_csrf_token(self, token_header_value: Optional[str], user_token: str) -> bool:
Expand All @@ -49,10 +49,11 @@ def validate_csrf_token(self, token_header_value: Optional[str], user_token: str
if ts_now - timestamp > self.csrf_time_limit:
return False

if not hmac.compare_digest(generate_csrf_token(user_token, timestamp, self.csrf_secret), token):
return False
for csrf_secret in self.csrf_secrets:
if hmac.compare_digest(generate_csrf_token(user_token, timestamp, csrf_secret), token):
return True

return True
return False

@web.middleware
@aiohttp_wrappers.DLRequestBase.use_dl_request_on_method
Expand Down
78 changes: 71 additions & 7 deletions lib/dl_api_commons/dl_api_commons_tests/unit/aio/test_csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@
_SAMPLE_USER_ID = "123"
_SAMPLE_TIMESTAMP = 1
_VALID_CSRF_SECRET = "valid_secret"
_ANOTHER_VALID_CSRF_SECRET = "another_valid_secret"
_INVALID_CSRF_SECRET = "invalid_secret"
_CSRF_TIME_LIMIT = 3600 * 12

_AppFactory = Callable[[bool], Awaitable[TestClient]]
_AppFactory = Callable[[bool, tuple[str, ...]], Awaitable[TestClient]]


def ts_now():
Expand All @@ -45,7 +46,7 @@ class TestingCSRFMiddleware(CSRFMiddleware):

@pytest.fixture(scope="function")
async def csrf_app_factory(aiohttp_client) -> _AppFactory:
async def f(authorized: Optional[bool]) -> TestClient:
async def f(authorized: Optional[bool], secrets: tuple[str, ...]) -> TestClient:
async def non_csrf_handler(request: web.Request):
return web.json_response(dict(ok="ok"), status=200)

Expand All @@ -72,7 +73,7 @@ async def put(self) -> web.StreamResponse:
TestingCSRFMiddleware(
csrf_header_name="x-csrf-token",
csrf_time_limit=_CSRF_TIME_LIMIT,
csrf_secret=_VALID_CSRF_SECRET,
csrf_secrets=secrets,
csrf_methods=("POST", "PUT", "DELETE"),
).middleware,
]
Expand All @@ -89,7 +90,7 @@ async def put(self) -> web.StreamResponse:

@pytest.mark.asyncio
@pytest.mark.parametrize(
"case_name, method, authorized, headers, cookies",
"case_name, method, authorized, headers, cookies, secrets",
[
(
"OK_just_cookie",
Expand All @@ -101,6 +102,7 @@ async def put(self) -> web.StreamResponse:
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"OK_just_user_id",
Expand All @@ -112,6 +114,7 @@ async def put(self) -> web.StreamResponse:
)
},
{"does_not": "matter"},
(_VALID_CSRF_SECRET,),
),
(
"OK_cookie_and_user_id",
Expand All @@ -123,6 +126,7 @@ async def put(self) -> web.StreamResponse:
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"SKIP_non_csrf_method",
Expand All @@ -134,6 +138,7 @@ async def put(self) -> web.StreamResponse:
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"SKIP_no_cookies",
Expand All @@ -145,6 +150,7 @@ async def put(self) -> web.StreamResponse:
)
},
None,
(_VALID_CSRF_SECRET,),
),
(
"SKIP_no_user_tokens",
Expand All @@ -156,13 +162,51 @@ async def put(self) -> web.StreamResponse:
)
},
{"does_not": "matter"},
(_VALID_CSRF_SECRET,),
),
(
"SKIP_view_marker",
"PUT",
False,
{"x-csrf-token": "some_token"},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"OK_two_valid_secrets",
"POST",
True,
{
"x-csrf-token": "{}:{}".format(
generate_csrf_token(_SAMPLE_USER_ID, ts_now(), _VALID_CSRF_SECRET), ts_now()
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET, _ANOTHER_VALID_CSRF_SECRET),
),
(
"OK_two_valid_secrets_2",
"POST",
True,
{
"x-csrf-token": "{}:{}".format(
generate_csrf_token(_SAMPLE_USER_ID, ts_now(), _ANOTHER_VALID_CSRF_SECRET), ts_now()
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET, _ANOTHER_VALID_CSRF_SECRET),
),
(
"OK_multiple_secrets",
"POST",
True,
{
"x-csrf-token": "{}:{}".format(
generate_csrf_token(_SAMPLE_USER_ID, ts_now(), _ANOTHER_VALID_CSRF_SECRET), ts_now()
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_INVALID_CSRF_SECRET, _VALID_CSRF_SECRET, _ANOTHER_VALID_CSRF_SECRET),
),
],
)
Expand All @@ -172,9 +216,10 @@ async def test_csrf_ok(
authorized: bool,
headers: dict[str, str],
cookies: dict[str, str],
secrets: tuple[str, ...],
csrf_app_factory: _AppFactory,
):
client = await csrf_app_factory(authorized)
client = await csrf_app_factory(authorized, secrets)
resp = await client.request(
method=method,
path="/",
Expand All @@ -188,35 +233,39 @@ async def test_csrf_ok(

@pytest.mark.asyncio
@pytest.mark.parametrize(
"case_name, method, authorized, headers, cookies",
"case_name, method, authorized, headers, cookies, secrets",
[
(
"INVALID_no_csrf_token_provided",
"POST",
False,
{},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"INVALID_malformed_csrf_token_1",
"POST",
False,
{"x-csrf-token": "asdf:1234:5678"},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"INVALID_malformed_csrf_token_2",
"POST",
False,
{"x-csrf-token": "asdf:qwer"},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"INVALID_empty_token",
"POST",
False,
{"x-csrf-token": ":1234"},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"INVALID_token_expired",
Expand All @@ -229,6 +278,7 @@ async def test_csrf_ok(
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"INVALID_bad_secret",
Expand All @@ -240,6 +290,19 @@ async def test_csrf_ok(
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET,),
),
(
"INVALID_bad_secret_multiple_secrets",
"POST",
False,
{
"x-csrf-token": "{}:{}".format(
generate_csrf_token(_SAMPLE_USER_ID, ts_now(), _INVALID_CSRF_SECRET), ts_now()
)
},
{"user_id_cookie": _SAMPLE_USER_ID},
(_VALID_CSRF_SECRET, _ANOTHER_VALID_CSRF_SECRET),
),
],
)
Expand All @@ -249,12 +312,13 @@ async def test_csrf_invalid(
authorized: bool,
headers: dict[str, str],
cookies: dict[str, str],
secrets: tuple[str, ...],
csrf_app_factory: _AppFactory,
):
validation_failed_text = "CSRF validation failed"
validation_failed_status = 400

client = await csrf_app_factory(authorized)
client = await csrf_app_factory(authorized, secrets)
resp = await client.request(
method=method,
path="/",
Expand Down
2 changes: 1 addition & 1 deletion lib/dl_configs/dl_configs/settings_submodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CsrfSettings(SettingsBase):
METHODS: tuple[str, ...] = s_attrib("METHODS", env_var_converter=split_by_comma) # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "tuple[str, ...]") [assignment]
HEADER_NAME: str = s_attrib("HEADER_NAME") # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "str") [assignment]
TIME_LIMIT: int = s_attrib("TIME_LIMIT") # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "int") [assignment]
SECRET: str = s_attrib("SECRET", sensitive=True, missing=None) # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "str") [assignment]
SECRET: tuple[str, ...] = s_attrib("SECRET", sensitive=True, missing=None, env_var_converter=split_by_comma) # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "str") [assignment]


@attr.s(frozen=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def create_app(self, app_version: str) -> web.Application:
self.CSRF_MIDDLEWARE_CLS(
csrf_header_name=self._settings.CSRF.HEADER_NAME,
csrf_time_limit=self._settings.CSRF.TIME_LIMIT,
csrf_secret=self._settings.CSRF.SECRET,
csrf_secrets=self._settings.CSRF.SECRET,
csrf_methods=self._settings.CSRF.METHODS,
).middleware,
# TODO FIX: Add when json_body_middleware will be moved to dl_core
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class FileUploaderAPISettings(FileUploaderBaseSettings):
METHODS=cfg.CSRF_METHODS,
HEADER_NAME=cfg.CSRF_HEADER_NAME,
TIME_LIMIT=cfg.CSRF_TIME_LIMIT,
SECRET=required(str),
SECRET=required(tuple[str, ...]),
# TODO: move this values to a separate key
)
if is_setting_applicable(cfg, "CSRF_METHODS")
Expand Down

0 comments on commit e3ce7f1

Please sign in to comment.