From e42966b5e682b71b60729c234b4ff1c05a3dae8b Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Sun, 1 Apr 2018 09:18:38 -0500 Subject: [PATCH] Integration test for sequences with Postgres. Refs #1555. --- docs/peewee/api.rst | 11 ++++++++--- peewee.py | 28 +++++++++++++++++++++------- tests/models.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/docs/peewee/api.rst b/docs/peewee/api.rst index 4a93bd0c2..22536a874 100644 --- a/docs/peewee/api.rst +++ b/docs/peewee/api.rst @@ -2841,9 +2841,11 @@ Schema Manager Execute CREATE TABLE query for the given model. - .. py:method:: drop_table([safe=True[, **options]]) + .. py:method:: drop_table([safe=True[, drop_sequences=True[, **options]]]) :param bool safe: Specify IF EXISTS clause. + :param bool drop_sequences: Drop any sequences associated with the + columns on the table (postgres only). :param options: Arbitrary options. Execute DROP TABLE query for the given model. @@ -2914,11 +2916,14 @@ Schema Manager Create sequence(s), index(es) and table for the model. - .. py:method:: drop_all([safe=True]) + .. py:method:: drop_all([safe=True[, drop_sequences=True[, **options]]]) :param bool safe: Whether to specify IF EXISTS. + :param bool drop_sequences: Drop any sequences associated with the + columns on the table (postgres only). + :param options: Arbitrary options. - Drop table for the model. + Drop table for the model and associated indexes. Model diff --git a/peewee.py b/peewee.py index 1d6686063..535d70153 100644 --- a/peewee.py +++ b/peewee.py @@ -4599,7 +4599,9 @@ def _create_sequence(self, field): .sql(Entity(field.sequence))) def create_sequence(self, field): - self.database.execute(self._create_sequence(field)) + seq_ctx = self._create_sequence(field) + if seq_ctx is not None: + self.database.execute(seq_ctx) def _drop_sequence(self, field): self._check_sequences(field) @@ -4610,7 +4612,9 @@ def _drop_sequence(self, field): .sql(Entity(field.sequence))) def drop_sequence(self, field): - self.database.execute(self._drop_sequence(field)) + seq_ctx = self._drop_sequence(field) + if seq_ctx is not None: + self.database.execute(seq_ctx) def _create_foreign_key(self, field): name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name, @@ -4628,17 +4632,27 @@ def _create_foreign_key(self, field): def create_foreign_key(self, field): self.database.execute(self._create_foreign_key(field)) - def create_all(self, safe=True, **table_options): + def create_sequences(self): if self.database.sequences: for field in self.model._meta.sorted_fields: - if field and field.sequence: + if field.sequence: self.create_sequence(field) + def create_all(self, safe=True, **table_options): + self.create_sequences() self.create_table(safe, **table_options) self.create_indexes(safe=safe) - def drop_all(self, safe=True, **options): + def drop_sequences(self): + if self.database.sequences: + for field in self.model._meta.sorted_fields: + if field.sequence: + self.drop_sequence(field) + + def drop_all(self, safe=True, drop_sequences=True, **options): self.drop_table(safe, **options) + if drop_sequences: + self.drop_sequences() class Metadata(object): @@ -5362,11 +5376,11 @@ def create_table(cls, safe=True, **options): cls._schema.create_all(safe, **options) @classmethod - def drop_table(cls, safe=True, **options): + def drop_table(cls, safe=True, drop_sequences=True, **options): if safe and not cls._meta.database.safe_drop_index \ and not cls.table_exists(): return - cls._schema.drop_all(safe, **options) + cls._schema.drop_all(safe, drop_sequences, **options) @classmethod def index(cls, *fields, **kwargs): diff --git a/tests/models.py b/tests/models.py index 832f2723a..9bde3f6d6 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2435,3 +2435,31 @@ def test_mixed_models_dict_row_type(self): {'content': 'walrus.txt', 'timestamp': self.ts(5)}, {'content': 'peewee.txt', 'timestamp': self.ts(4)}, {'content': 'note-c', 'timestamp': self.ts(3)}]) + + +class SequenceModel(TestModel): + seq_id = IntegerField(sequence='seq_id_sequence') + key = TextField() + + +@skip_case_unless(IS_POSTGRESQL) +class TestSequence(ModelTestCase): + requires = [SequenceModel] + + def test_create_table(self): + query = SequenceModel._schema._create_table() + self.assertSQL(query, ( + 'CREATE TABLE IF NOT EXISTS "sequencemodel" (' + '"id" SERIAL NOT NULL PRIMARY KEY, ' + '"seq_id" INTEGER NOT NULL DEFAULT NEXTVAL(\'seq_id_sequence\'), ' + '"key" TEXT NOT NULL)'), []) + + def test_sequence(self): + for key in ('k1', 'k2', 'k3'): + SequenceModel.create(key=key) + + s1, s2, s3 = SequenceModel.select().order_by(SequenceModel.key) + + self.assertEqual(s1.seq_id, 1) + self.assertEqual(s2.seq_id, 2) + self.assertEqual(s3.seq_id, 3)