Skip to content

Commit

Permalink
Integration test for sequences with Postgres.
Browse files Browse the repository at this point in the history
Refs #1555.
  • Loading branch information
coleifer committed Apr 1, 2018
1 parent e6012b2 commit e42966b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 10 deletions.
11 changes: 8 additions & 3 deletions docs/peewee/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2841,9 +2841,11 @@ Schema Manager

Execute CREATE TABLE query for the given model.

.. py:method:: drop_table([safe=True[, **options]])
.. py:method:: drop_table([safe=True[, drop_sequences=True[, **options]]])
:param bool safe: Specify IF EXISTS clause.
:param bool drop_sequences: Drop any sequences associated with the
columns on the table (postgres only).
:param options: Arbitrary options.

Execute DROP TABLE query for the given model.
Expand Down Expand Up @@ -2914,11 +2916,14 @@ Schema Manager

Create sequence(s), index(es) and table for the model.

.. py:method:: drop_all([safe=True])
.. py:method:: drop_all([safe=True[, drop_sequences=True[, **options]]])
:param bool safe: Whether to specify IF EXISTS.
:param bool drop_sequences: Drop any sequences associated with the
columns on the table (postgres only).
:param options: Arbitrary options.

Drop table for the model.
Drop table for the model and associated indexes.


Model
Expand Down
28 changes: 21 additions & 7 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -4599,7 +4599,9 @@ def _create_sequence(self, field):
.sql(Entity(field.sequence)))

def create_sequence(self, field):
self.database.execute(self._create_sequence(field))
seq_ctx = self._create_sequence(field)
if seq_ctx is not None:
self.database.execute(seq_ctx)

def _drop_sequence(self, field):
self._check_sequences(field)
Expand All @@ -4610,7 +4612,9 @@ def _drop_sequence(self, field):
.sql(Entity(field.sequence)))

def drop_sequence(self, field):
self.database.execute(self._drop_sequence(field))
seq_ctx = self._drop_sequence(field)
if seq_ctx is not None:
self.database.execute(seq_ctx)

def _create_foreign_key(self, field):
name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name,
Expand All @@ -4628,17 +4632,27 @@ def _create_foreign_key(self, field):
def create_foreign_key(self, field):
self.database.execute(self._create_foreign_key(field))

def create_all(self, safe=True, **table_options):
def create_sequences(self):
if self.database.sequences:
for field in self.model._meta.sorted_fields:
if field and field.sequence:
if field.sequence:
self.create_sequence(field)

def create_all(self, safe=True, **table_options):
self.create_sequences()
self.create_table(safe, **table_options)
self.create_indexes(safe=safe)

def drop_all(self, safe=True, **options):
def drop_sequences(self):
if self.database.sequences:
for field in self.model._meta.sorted_fields:
if field.sequence:
self.drop_sequence(field)

def drop_all(self, safe=True, drop_sequences=True, **options):
self.drop_table(safe, **options)
if drop_sequences:
self.drop_sequences()


class Metadata(object):
Expand Down Expand Up @@ -5362,11 +5376,11 @@ def create_table(cls, safe=True, **options):
cls._schema.create_all(safe, **options)

@classmethod
def drop_table(cls, safe=True, **options):
def drop_table(cls, safe=True, drop_sequences=True, **options):
if safe and not cls._meta.database.safe_drop_index \
and not cls.table_exists():
return
cls._schema.drop_all(safe, **options)
cls._schema.drop_all(safe, drop_sequences, **options)

@classmethod
def index(cls, *fields, **kwargs):
Expand Down
28 changes: 28 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,3 +2435,31 @@ def test_mixed_models_dict_row_type(self):
{'content': 'walrus.txt', 'timestamp': self.ts(5)},
{'content': 'peewee.txt', 'timestamp': self.ts(4)},
{'content': 'note-c', 'timestamp': self.ts(3)}])


class SequenceModel(TestModel):
seq_id = IntegerField(sequence='seq_id_sequence')
key = TextField()


@skip_case_unless(IS_POSTGRESQL)
class TestSequence(ModelTestCase):
requires = [SequenceModel]

def test_create_table(self):
query = SequenceModel._schema._create_table()
self.assertSQL(query, (
'CREATE TABLE IF NOT EXISTS "sequencemodel" ('
'"id" SERIAL NOT NULL PRIMARY KEY, '
'"seq_id" INTEGER NOT NULL DEFAULT NEXTVAL(\'seq_id_sequence\'), '
'"key" TEXT NOT NULL)'), [])

def test_sequence(self):
for key in ('k1', 'k2', 'k3'):
SequenceModel.create(key=key)

s1, s2, s3 = SequenceModel.select().order_by(SequenceModel.key)

self.assertEqual(s1.seq_id, 1)
self.assertEqual(s2.seq_id, 2)
self.assertEqual(s3.seq_id, 3)

0 comments on commit e42966b

Please sign in to comment.