diff --git a/django_scrubber/management/commands/scrub_data.py b/django_scrubber/management/commands/scrub_data.py index 7880250..b549203 100644 --- a/django_scrubber/management/commands/scrub_data.py +++ b/django_scrubber/management/commands/scrub_data.py @@ -7,6 +7,7 @@ from django.contrib.sessions.models import Session from django.core.exceptions import FieldDoesNotExist from django.core.management.base import BaseCommand, CommandError +from django.db import DEFAULT_DB_ALIAS from django.db.models import F from django.db.utils import IntegrityError, DataError @@ -29,6 +30,10 @@ def add_arguments(self, parser): help='Will truncate the database table storing preprocessed data for the Faker library. ' 'If you want to do multiple iterations of scrubbing, it will save you time to keep ' 'them. If not, you will add a huge bunch of data to your dump size.') + parser.add_argument('--database', default=DEFAULT_DB_ALIAS, + help='Nominates a database to scrub. Defaults to the "default" database.') + + def handle(self, *args, **kwargs): if not settings.DEBUG: @@ -36,6 +41,8 @@ def handle(self, *args, **kwargs): self.stderr.write('this command should only be run with DEBUG=True, to avoid running on live systems') return False + database = kwargs['database'] + global_scrubbers = settings_with_fallback('SCRUBBER_GLOBAL_SCRUBBERS') # run only for selected model @@ -78,7 +85,7 @@ def handle(self, *args, **kwargs): try: model.objects.annotate( mod_pk=F('pk') % settings_with_fallback('SCRUBBER_ENTRIES_PER_PROVIDER') - ).update(**realized_scrubbers) + ).using(database).update(**realized_scrubbers) except IntegrityError as e: raise CommandError('Integrity error while scrubbing %s (%s); maybe increase ' 'SCRUBBER_ENTRIES_PER_PROVIDER?' % (model, e)) @@ -87,11 +94,11 @@ def handle(self, *args, **kwargs): # Truncate session data if not kwargs.get('keep_sessions', False): - Session.objects.all().delete() + Session.objects.using(database).all().delete() # Truncate Faker data if kwargs.get('remove_fake_data', False): - FakeData.objects.all().delete() + FakeData.objects.using(database).all().delete() def _call_callables(d): diff --git a/tests/models.py b/tests/models.py index 23bdc81..6a78e95 100644 --- a/tests/models.py +++ b/tests/models.py @@ -15,3 +15,9 @@ class DataToBeScrubbed(models.Model): class DataFactory(DjangoModelFactory): class Meta: model = DataToBeScrubbed + + +class OtherDatabaseDataFactory(DjangoModelFactory): + class Meta: + model = DataToBeScrubbed + database = 'other' diff --git a/tests/settings.py b/tests/settings.py index 497ab4a..963d725 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -14,6 +14,12 @@ 'NAME': ':memory:', 'OPTIONS': { } + }, + 'other': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': ':memory:', + 'OPTIONS': { + } } } diff --git a/tests/test_scrub_data.py b/tests/test_scrub_data.py index 9799b66..e5d9e6e 100644 --- a/tests/test_scrub_data.py +++ b/tests/test_scrub_data.py @@ -13,14 +13,25 @@ User = get_user_model() +class BaseDatabaseTestCase(TestCase): + databases = ['default'] + + def _db(self): + # the frameworks transforms self.databases into a set... + return list(self.databases)[0] + + +class OtherDatabaseMixin: + databases = ['other'] + +class TestScrubData(BaseDatabaseTestCase): -class TestScrubData(TestCase): def setUp(self): - self.user = User.objects.create(first_name='test_first_name') + self.user = User.objects.using(self._db()).create(first_name='test_first_name') def test_scrub_data(self): with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'first_name': scrubbers.Faker('first_name')}): - call_command('scrub_data', verbosity=3) + call_command('scrub_data', verbosity=3, database=self._db()) self.user.refresh_from_db() self.assertNotEqual(self.user.first_name, 'test_first_name') @@ -29,7 +40,7 @@ def test_scrub_data_debug_is_false(self): err = StringIO() with self.settings(DEBUG=False): - call_command('scrub_data', stderr=err) + call_command('scrub_data', stderr=err, database=self._db()) output = err.getvalue() self.user.refresh_from_db() @@ -38,7 +49,7 @@ def test_scrub_data_debug_is_false(self): def test_hash_simple_global_scrubber(self): with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'first_name': scrubbers.Hash}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) self.user.refresh_from_db() self.assertNotEqual(self.user.first_name, 'test_first_name') @@ -48,7 +59,7 @@ class Scrubbers: first_name = scrubbers.Hash with self.settings(DEBUG=True), patch.object(User, 'Scrubbers', Scrubbers, create=True): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) self.user.refresh_from_db() self.assertNotEqual(self.user.first_name, 'test_first_name') @@ -61,7 +72,7 @@ class Scrubbers: with self.assertWarnsRegex( Warning, 'Scrubber defined for User.this_does_not_exist_382784 but field does not exist' ): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) @override_settings(SCRUBBER_MAPPING={"auth.User": "tests.scrubbers.UserScrubbers"}) def test_get_model_scrubbers_mapper_from_settings_used(self): @@ -83,3 +94,7 @@ def test_parse_scrubber_class_from_string_wrong_path(self): def test_parse_scrubber_class_from_string_path_no_separator(self): with self.assertRaises(ImportError): _parse_scrubber_class_from_string('broken_path') + + +class TestScrubDataOnOtherDatabase(OtherDatabaseMixin, TestScrubData): + pass diff --git a/tests/test_scrubbers.py b/tests/test_scrubbers.py index 5b0e5ef..5cae3ce 100644 --- a/tests/test_scrubbers.py +++ b/tests/test_scrubbers.py @@ -11,12 +11,16 @@ from django_scrubber import scrubbers from django_scrubber.models import FakeData -from .models import DataFactory, DataToBeScrubbed +from .models import DataFactory, OtherDatabaseDataFactory, DataToBeScrubbed +from .test_scrub_data import BaseDatabaseTestCase, OtherDatabaseMixin +class TestScrubbers(BaseDatabaseTestCase): + + def _data_factory(self): + return DataFactory if self._db() == 'default' else OtherDatabaseDataFactory -class TestScrubbers(TestCase): def test_empty_scrubber(self): - data = DataFactory.create(first_name='Foo') + data = self._data_factory().create(first_name='Foo') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'first_name': scrubbers.Empty}): call_command('scrub_data') data.refresh_from_db() @@ -24,17 +28,17 @@ def test_empty_scrubber(self): self.assertEqual(data.first_name, '') def test_null_scrubber(self): - data = DataFactory.create(last_name='Foo') + data = self._data_factory().create(last_name='Foo') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'last_name': scrubbers.Null}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() self.assertEqual(data.last_name, None) def test_hash_scrubber_max_length(self): - data = DataFactory.create(first_name='Foo') + data = self._data_factory().create(first_name='Foo') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'first_name': scrubbers.Hash}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() self.assertNotEqual(data.first_name, 'Foo') @@ -45,26 +49,26 @@ def test_hash_scrubber_max_length(self): ) def test_hash_scrubber_textfield(self): - data = DataFactory.create(description='Foo') + data = self._data_factory().create(description='Foo') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'description': scrubbers.Hash}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() self.assertNotEqual(data.description, 'Foo') def test_lorem_scrubber(self): - data = DataFactory.create(description='Foo') + data = self._data_factory().create(description='Foo') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'description': scrubbers.Lorem}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() self.assertNotEqual(data.description, 'Foo') self.assertEqual(data.description[:11], 'Lorem ipsum') def test_faker_scrubber_charfield(self): - data = DataFactory.create(last_name='Foo') + data = self._data_factory().create(last_name='Foo') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'last_name': scrubbers.Faker('last_name')}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() self.assertNotEqual(data.last_name, 'Foo') @@ -74,9 +78,9 @@ def test_faker_scrubber_with_provider_arguments(self): """ Use this as an example for Faker scrubbers with parameters passed along """ - data = DataFactory.create(ean8='8') + data = self._data_factory().create(ean8='8') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'ean8': scrubbers.Faker('ean', length=8)}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() # The EAN Faker will by default emit ean13, so this would fail if the parameter was ignored @@ -84,7 +88,7 @@ def test_faker_scrubber_with_provider_arguments(self): # Add a new scrubber for ean13 with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={'ean8': scrubbers.Faker('ean', length=13)}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() # make sure it doesn't reuse the ean with length=8 scrubber @@ -96,10 +100,10 @@ def test_faker_scrubber_datefield(self): There is a bug with django < 2.1 and sqlite, that's why we don't run the test there. """ if django.VERSION >= (2, 1) or connection.vendor != "sqlite": - data = DataFactory.create(date_past=date.today()) + data = self._data_factory().create(date_past=date.today()) with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={ 'date_past': scrubbers.Faker('past_date', start_date="-30d", tzinfo=None)}): - call_command('scrub_data') + call_command('scrub_data', database=self._db()) data.refresh_from_db() self.assertGreater(date.today(), data.date_past) @@ -109,11 +113,11 @@ def test_faker_scrubber_run_twice(self): """ Use this as an example of what happens when you want to run the same Faker scrubbers twice """ - data = DataFactory.create(company='Foo') + data = self._data_factory().create(company='Foo') with self.settings(DEBUG=True, SCRUBBER_GLOBAL_SCRUBBERS={ 'company': scrubbers.Faker('company')}): - call_command('scrub_data') - call_command('scrub_data') + call_command('scrub_data', database=self._db()) + call_command('scrub_data', database=self._db()) data.refresh_from_db() self.assertNotEqual(data.company, 'Foo') @@ -125,16 +129,16 @@ def test_faker_scrubber_run_clear_session_by_default(self): Ensures that the session table will be emptied by default """ # Create session object - Session.objects.create(session_key='foo', session_data='Lorem ipsum', expire_date=timezone.now()) + Session.objects.using(self._db()).create(session_key='foo', session_data='Lorem ipsum', expire_date=timezone.now()) # Sanity check - self.assertTrue(Session.objects.all().exists()) + self.assertTrue(Session.objects.using(self._db()).all().exists()) # Call command - call_command('scrub_data') + call_command('scrub_data', database=self._db()) # Assertion that session table is empty now - self.assertFalse(Session.objects.all().exists()) + self.assertFalse(Session.objects.using(self._db()).all().exists()) @override_settings(DEBUG=True) def test_faker_scrubber_run_disable_session_clearing(self): @@ -142,16 +146,16 @@ def test_faker_scrubber_run_disable_session_clearing(self): Ensures that the session table will be emptied by default """ # Create session object - Session.objects.create(session_key='foo', session_data='Lorem ipsum', expire_date=timezone.now()) + Session.objects.using(self._db()).create(session_key='foo', session_data='Lorem ipsum', expire_date=timezone.now()) # Sanity check - self.assertTrue(Session.objects.all().exists()) + self.assertTrue(Session.objects.using(self._db()).all().exists()) # Call command - call_command('scrub_data', keep_sessions=True) + call_command('scrub_data', keep_sessions=True, database=self._db()) # Assertion that session table is empty now - self.assertTrue(Session.objects.all().exists()) + self.assertTrue(Session.objects.using(self._db()).all().exists()) @override_settings(DEBUG=True) def test_faker_scrubber_run_clear_faker_data_not_by_default(self): @@ -159,13 +163,13 @@ def test_faker_scrubber_run_clear_faker_data_not_by_default(self): Ensures that the session table will be emptied by default """ # Create faker data object - FakeData.objects.create(provider='company', content='Foo', provider_offset=1) + FakeData.objects.using(self._db()).create(provider='company', content='Foo', provider_offset=1) # Sanity check self.assertTrue(FakeData.objects.filter(provider='company', content='Foo').exists()) # Call command - call_command('scrub_data') + call_command('scrub_data', database=self._db()) # Assertion that faker data still exists self.assertTrue(FakeData.objects.filter(provider='company', content='Foo').exists()) @@ -176,13 +180,16 @@ def test_faker_scrubber_run_clear_faker_data_works(self): Ensures that the session table will be emptied by default """ # Create faker data object - FakeData.objects.create(provider='company', content='Foo', provider_offset=1) + FakeData.objects.using(self._db()).create(provider='company', content='Foo', provider_offset=1) # Sanity check self.assertTrue(FakeData.objects.filter(provider='company', content='Foo').exists()) # Call command - call_command('scrub_data', remove_fake_data=True) + call_command('scrub_data', remove_fake_data=True, database=self._db()) # Assertion that faker data still exists self.assertFalse(FakeData.objects.filter(provider='company', content='Foo').exists()) + +class TestScrubbersOnOtherDatabase(OtherDatabaseMixin, TestScrubbers): + pass