diff --git a/firebird/introspection.py b/firebird/introspection.py index 7efe4de..4d31756 100644 --- a/firebird/introspection.py +++ b/firebird/introspection.py @@ -186,12 +186,12 @@ def get_indexes(self, cursor, table_name): cursor.execute(""" SELECT LOWER(s.RDB$FIELD_NAME) AS field_name, - + LOWER(case when rc.RDB$CONSTRAINT_TYPE is not null then rc.RDB$CONSTRAINT_TYPE else 'INDEX' end) AS constraint_type - + FROM RDB$INDEX_SEGMENTS s LEFT JOIN RDB$INDICES i ON i.RDB$INDEX_NAME = s.RDB$INDEX_NAME LEFT JOIN RDB$RELATION_CONSTRAINTS rc ON rc.RDB$INDEX_NAME = s.RDB$INDEX_NAME @@ -250,7 +250,7 @@ def get_constraints(self, cursor, table_name): when s.RDB$FIELD_NAME is not null then s.RDB$FIELD_NAME else '' end AS field_name, - + i2.RDB$RELATION_NAME AS references_table, s2.RDB$FIELD_NAME AS references_field, i.RDB$UNIQUE_FLAG, @@ -329,6 +329,21 @@ def _get_field_indexes(self, cursor, table_name, field_name): return [index_name[0].strip() for index_name in cursor.fetchall()] + def _get_check_constraints(self, cursor, table_name, field_name): + table = "'%s'" % table_name.upper() + field = "'%s'" % field_name.upper() + cursor.execute(""" + select distinct rc.rdb$constraint_name as check_constraint_name + from rdb$relation_constraints rc + join rdb$check_constraints cc on rc.rdb$constraint_name = cc.rdb$constraint_name + join rdb$triggers c on cc.rdb$trigger_name = c.rdb$trigger_name + join rdb$relation_fields rf on rc.rdb$relation_name = rf.rdb$relation_name + where rc.rdb$relation_name = %s + and rf.rdb$field_name = %s + and rc.rdb$constraint_type = 'CHECK' + """ % (table, field,)) + return [cn[0].strip() for cn in cursor.fetchall()] + def _name_to_index(self, cursor, table_name): """Return a dictionary of {field_name: field_index} for the given table. Indexes are 0-based. diff --git a/firebird/schema.py b/firebird/schema.py index e37678d..0071bbb 100644 --- a/firebird/schema.py +++ b/firebird/schema.py @@ -47,8 +47,8 @@ def _alter_column_set_null(self, table_name, column_name, is_null): engine_ver = str(self.connection.connection.engine_version).split('.') if engine_ver and len(engine_ver) > 0 and int(engine_ver[0]) >= 3: sql = """ - ALTER TABLE \"%(table_name)s\" - ALTER \"%(column)s\" + ALTER TABLE \"%(table_name)s\" + ALTER \"%(column)s\" %(null_flag)s NOT NULL """ null_flag = 'DROP' if is_null else 'SET' @@ -237,6 +237,12 @@ def _get_field_indexes(self, model, field): indexes = self.connection.introspection._get_field_indexes(cursor, model._meta.db_table, field.column) return indexes + def _get_field_check_constraints(self, model, field): + with self.connection.cursor() as cursor: + db_table = model._meta.db_table + checks = self.connection.introspection._get_check_constraints(cursor, db_table, field.column) + return checks + def remove_field(self, model, field): # If remove a AutoField, we need remove all related stuff # if isinstance(field, AutoField): @@ -252,6 +258,11 @@ def remove_field(self, model, field): sql = self._delete_constraint_sql(self.sql_delete_index, model, index_name) self.execute(sql) + # If field has check constraint, then remove it first + for check_constraint_name in self._get_field_check_constraints(model, field): + sql = self._delete_constraint_sql(self.sql_delete_constraint, model, check_constraint_name) + self.execute(sql) + super(DatabaseSchemaEditor, self).remove_field(model, field) def _alter_column_type_sql(self, table, old_field, new_field, new_type): @@ -875,4 +886,4 @@ def execute(self, sql, params=[]): except Exception as e: raise e else: - super(DatabaseSchemaEditor, self).execute(sql, params) \ No newline at end of file + super(DatabaseSchemaEditor, self).execute(sql, params)