From bc5b3dcb7ca1ce75c076384e718d66cc806d8e9c Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Wed, 10 Apr 2024 14:41:09 +0000 Subject: [PATCH] Reduce test flakiness --- fixbackend/auth/user_manager.py | 22 +++++++++++++++++++--- tests/fixbackend/auth/router_test.py | 14 ++++++-------- tests/fixbackend/conftest.py | 22 ++++++++++++++++++++++ 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/fixbackend/auth/user_manager.py b/fixbackend/auth/user_manager.py index 7cb20a4a..d85e752c 100644 --- a/fixbackend/auth/user_manager.py +++ b/fixbackend/auth/user_manager.py @@ -69,6 +69,8 @@ def __init__( self.workspace_repository = workspace_repository self.domain_events_publisher = domain_events_publisher self.invitation_repository = invitation_repository + self.custom_password_helper = password_helper is not None + self.otp_valid_window = 1 def parse_id(self, value: Any) -> UserId: if isinstance(value, UUID): @@ -219,6 +221,12 @@ async def oauth_associate_callback( return user async def compute_recovery_codes(self) -> Tuple[list[str], list[str]]: + # use custom password helper if provided, e.g. for testing + if self.custom_password_helper: + recovery_codes = [secrets.token_hex(16) for _ in range(10)] + hashes = [self.password_helper.hash(code) for code in recovery_codes] + return recovery_codes, hashes + # create recovery codes recovery_codes = [secrets.token_hex(16) for _ in range(10)] # create hashes of the recovery codes @@ -246,7 +254,7 @@ async def recreate_mfa(self, user: User) -> OTPConfig: async def enable_mfa(self, user: User, otp: str) -> bool: assert not user.is_mfa_active, "User already has MFA enabled." - if (secret := user.otp_secret) and not pyotp.TOTP(secret).verify(otp, valid_window=1): + if (secret := user.otp_secret) and not pyotp.TOTP(secret).verify(otp, valid_window=self.otp_valid_window): return False await self.user_repository.update(user, {"is_mfa_active": True}) return True @@ -263,15 +271,23 @@ async def check_otp(self, user: User, otp: Optional[str], recovery_code: Optiona if not user.is_mfa_active: return True if (secret := user.otp_secret) and (otp_defined := otp): - return pyotp.TOTP(secret).verify(otp_defined) + return pyotp.TOTP(secret).verify(otp_defined, valid_window=self.otp_valid_window) if recovery_code: return await self.user_repository.delete_recovery_code(user.id, recovery_code, self.password_helper) return False +def get_password_helper() -> PasswordHelperProtocol | None: + return None + + +PasswordHelperDependency = Annotated[PasswordHelperProtocol | None, Depends(get_password_helper)] + + async def get_user_manager( config: ConfigDependency, user_repository: UserRepositoryDependency, + password_helper: PasswordHelperDependency, user_verifier: AuthEmailSenderDependency, workspace_repository: WorkspaceRepositoryDependency, domain_event_publisher: DomainEventPublisherDependency, @@ -280,7 +296,7 @@ async def get_user_manager( yield UserManager( config, user_repository, - None, + password_helper, user_verifier, workspace_repository, domain_event_publisher, diff --git a/tests/fixbackend/auth/router_test.py b/tests/fixbackend/auth/router_test.py index 3dff7904..24fd80f2 100644 --- a/tests/fixbackend/auth/router_test.py +++ b/tests/fixbackend/auth/router_test.py @@ -12,7 +12,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import os from typing import Callable, List, Optional, Sequence, Tuple, override import jwt @@ -27,6 +26,7 @@ from fixbackend.auth.models import User from fixbackend.auth.models.orm import UserMFARecoveryCode from fixbackend.auth.schemas import OTPConfig +from fixbackend.auth.user_manager import get_password_helper from fixbackend.auth.user_repository import UserRepository from fixbackend.auth.user_verifier import AuthEmailSender, get_auth_email_sender from fixbackend.domain_events.dependencies import get_domain_event_publisher @@ -38,7 +38,7 @@ from fixbackend.workspaces.invitation_repository import InvitationRepository, get_invitation_repository from fixbackend.workspaces.models import WorkspaceInvitation from fixbackend.workspaces.repository import WorkspaceRepository -from tests.fixbackend.conftest import InMemoryDomainEventPublisher +from tests.fixbackend.conftest import InMemoryDomainEventPublisher, InsecureFastPasswordHelper from fixbackend.certificates.cert_store import CertificateStore @@ -122,9 +122,6 @@ async def remove_roles( @pytest.mark.asyncio -@pytest.mark.skipif( - os.getenv("LOCAL_DEV_ENV") is not None, reason="Skipping in local dev environment for performance reasons." -) async def test_registration_flow( api_client: AsyncClient, fast_api: FastAPI, @@ -132,6 +129,7 @@ async def test_registration_flow( workspace_repository: WorkspaceRepository, user_repository: UserRepository, cert_store: CertificateStore, + password_helper: InsecureFastPasswordHelper, ) -> None: verifier = InMemoryVerifier() invitation_repo = InMemoryInvitationRepo() @@ -140,6 +138,7 @@ async def test_registration_flow( fast_api.dependency_overrides[get_domain_event_publisher] = lambda: domain_event_sender fast_api.dependency_overrides[get_invitation_repository] = lambda: invitation_repo fast_api.dependency_overrides[get_role_repository] = lambda: role_repo + fast_api.dependency_overrides[get_password_helper] = lambda: password_helper registration_json = { "email": "user@example.com", @@ -214,14 +213,12 @@ async def test_registration_flow( @pytest.mark.asyncio -@pytest.mark.skipif( - os.getenv("LOCAL_DEV_ENV") is not None, reason="Skipping in local dev environment for performance reasons." -) async def test_mfa_flow( api_client: AsyncClient, fast_api: FastAPI, domain_event_sender: InMemoryDomainEventPublisher, user_repository: UserRepository, + password_helper: InsecureFastPasswordHelper, ) -> None: verifier = InMemoryVerifier() invitation_repo = InMemoryInvitationRepo() @@ -230,6 +227,7 @@ async def test_mfa_flow( fast_api.dependency_overrides[get_domain_event_publisher] = lambda: domain_event_sender fast_api.dependency_overrides[get_invitation_repository] = lambda: invitation_repo fast_api.dependency_overrides[get_role_repository] = lambda: role_repo + fast_api.dependency_overrides[get_password_helper] = lambda: password_helper # register user registration_json = {"email": "user2@example.com", "password": "changeme"} diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py index 28156731..ea739bac 100644 --- a/tests/fixbackend/conftest.py +++ b/tests/fixbackend/conftest.py @@ -13,12 +13,14 @@ # along with this program. If not, see . import asyncio +import hashlib import json import os from argparse import Namespace from asyncio import AbstractEventLoop from datetime import datetime, timezone from pathlib import Path +import random from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Sequence, Tuple, Optional from unittest.mock import patch @@ -39,6 +41,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy_utils import create_database, database_exists, drop_database +from fastapi_users.password import PasswordHelper from fixbackend.analytics import AnalyticsEventSender from fixbackend.analytics.events import AnalyticsEvent from fixbackend.app import fast_api_app @@ -688,6 +691,25 @@ async def cert_store(default_config: Config) -> CertificateStore: return CertificateStore(default_config) +class InsecureFastPasswordHelper(PasswordHelper): + def __init__(self) -> None: + pass + + def verify_and_update(self, plain_password: str, hashed_password: str) -> Tuple[bool, str]: + return hashed_password == hashlib.md5(plain_password.encode()).hexdigest(), hashed_password + + def hash(self, password: str) -> str: + return hashlib.md5(password.encode()).hexdigest() + + def generate(self) -> str: + return str(random.randint(100000, 999999)) + + +@pytest.fixture +def password_helper() -> InsecureFastPasswordHelper: + return InsecureFastPasswordHelper() + + @pytest.fixture async def fix_deps( db_engine: AsyncEngine,