Add ability to ignore CSV fields as requested in #15 (#30)

Closed
wants to merge 13 commits
57 changes: 33 additions & 24 deletions docs/index.rst
@@ -143,31 +143,40 @@ Argument Description
                  string field names for the CSV header.
================= =========================================================

======================= =====================================================
Keyword Arguments
======================= =====================================================
``delimiter``           The character that separates values in the data file.
                        By default it is ",". This must be a single one-byte
                        character.

``null``                Specifies the string that represents a null value.
                        The default is an unquoted empty string. This must
                        be a single one-byte character.

``encoding``            Specifies the character set encoding of the strings
                        in the CSV data source. For example, ``'latin-1'``,
                        ``'utf-8'``, and ``'cp437'`` are all valid encoding
                        parameters.

``using``               Sets the database to use when importing data.
                        Default is None, which will use the ``'default'``
                        database.

``static_mapping``      Set model attributes not in the CSV to the same value
                        for every row in the database by providing a dictionary
                        with the names of the columns as keys and the static
                        inputs as values.

``ignore_headers``      A list of headers in your CSV that have no
                        equivalent fields on your model. These columns
                        will be ignored.

``overloaded_mapping``  Reuse a mapped CSV column for an additional model
                        field. This is useful when you want to store both the
                        original value and a modified form, typically using a
                        ``copy_template`` to transform the second value
                        (see the example below).
======================= =====================================================
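
For example, a usage sketch combining the new keyword arguments, based on the models exercised in this pull request's test suite; the import path, the CSV path and its NAME, NUMBER and DATE headers are illustrative assumptions rather than part of the documented API:

from postgres_copy import CopyMapping
# Hypothetical import path; these models mirror the ones in this PR's tests.
from myapp.models import LimitedMockObject, OverloadMockObject

# Ignore a CSV column that has no matching field on the model.
c = CopyMapping(
    LimitedMockObject,
    '/path/to/names.csv',  # assumed CSV with NAME, NUMBER and DATE headers
    dict(name='NAME', dt='DATE'),
    ignore_headers=['NUMBER'],
)
c.save()

# Load the NAME column twice: into name via copy_name_template() and into
# lower_name via copy_lower_name_template().
c = CopyMapping(
    OverloadMockObject,
    '/path/to/names.csv',
    dict(name='NAME', number='NUMBER', dt='DATE'),
    overloaded_mapping=dict(lower_name='NAME'),
)
c.save()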


``save()`` keyword arguments
73 changes: 58 additions & 15 deletions postgres_copy/__init__.py
@@ -1,27 +1,31 @@
import os
import sys
import csv
from collections import OrderedDict

from django.contrib.humanize.templatetags.humanize import intcomma
from django.db import connections, router


class CopyMapping(object):
    """
    Maps a comma-delimited data file to a Django model
    and loads it into PostgreSQL databases using its
    COPY command.
    """

    def __init__(
        self,
        model,
        csv_path,
        mapping,
        ignore_headers=None,
        using=None,
        delimiter=',',
        null=None,
        encoding=None,
        static_mapping=None,
        overloaded_mapping=None
    ):
        self.model = model
        self.mapping = mapping
@@ -37,34 +41,54 @@ def __init__(
        if self.conn.vendor != 'postgresql':
            raise TypeError("Only PostgreSQL backends supported")
        self.backend = self.conn.ops
        if ignore_headers is None:
            self.ignore_headers = []
        else:
            self.ignore_headers = ignore_headers
        self.delimiter = delimiter
        self.null = null
        self.encoding = encoding
        if static_mapping is not None:
            self.static_mapping = OrderedDict(static_mapping)
        else:
            self.static_mapping = {}
        if overloaded_mapping is not None:
            self.overloaded_mapping = overloaded_mapping
        else:
            self.overloaded_mapping = {}

        # Connect the headers from the CSV with the fields on the model
        self.field_header_crosswalk = []
        inverse_mapping = {v: k for k, v in self.mapping.items()}
        for ignore in self.ignore_headers:
            inverse_mapping[ignore] = ignore.lower()
        for h in self.get_headers():
            try:
                f_name = inverse_mapping[h]
            except KeyError:
                raise ValueError("Map does not include %s field" % h)
            try:
                if f_name not in [ih.lower() for ih in self.ignore_headers]:
                    f = [f for f in self.model._meta.fields
                         if f.name == f_name][0]
            except IndexError:
                raise ValueError("Model does not include %s field" % f_name)
            self.field_header_crosswalk.append((f, h))

        # Validate that the static mapping columns exist
        for f_name in self.static_mapping.keys():
            try:
                [s for s in self.model._meta.fields if s.name == f_name][0]
            except IndexError:
                raise ValueError("Model does not include %s field" % f_name)
        # Validate overloaded headers and fields
        self.overloaded_crosswalk = []
        for k, v in self.overloaded_mapping.items():
            try:
                o = [o for o in self.model._meta.fields if o.name == k][0]
                self.overloaded_crosswalk.append((o, v))
            except IndexError:
                raise ValueError("Model does not include overload %s field"
                                 % v)

        self.temp_table_name = "temp_%s" % self.model._meta.db_table
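
For orientation, a standalone sketch (not the library's code) of how the ignore_headers list folds into the header-to-field lookup built above, using the mapping from this PR's LimitedMockObject test and assuming the CSV header row is NAME, NUMBER, DATE:

# Standalone sketch of the lookup logic above; not part of the diff.
mapping = dict(name='NAME', dt='DATE')
ignore_headers = ['NUMBER']

inverse_mapping = {v: k for k, v in mapping.items()}
for ignore in ignore_headers:
    inverse_mapping[ignore] = ignore.lower()  # 'NUMBER' -> 'number'

for header in ['NAME', 'NUMBER', 'DATE']:  # assumed CSV header row
    field_name = inverse_mapping[header]  # an unmapped header raises "Map does not include ..."
    if field_name in [ih.lower() for ih in ignore_headers]:
        print(header, '-> ignored, no model field is looked up')
    else:
        print(header, '->', field_name)  # the model field is then fetched by this name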

@@ -173,7 +197,7 @@ def prep_copy(self):
            'extra_options': '',
            'header_list': ", ".join([
                '"%s"' % h for f, h in self.field_header_crosswalk
            ])
        }
        if self.delimiter:
            options['extra_options'] += " DELIMITER '%s'" % self.delimiter
@@ -204,24 +228,43 @@ def prep_insert(self):
        model_fields = []

        for field, header in self.field_header_crosswalk:
            if header in self.ignore_headers:
                continue
            model_fields.append('"%s"' % field.get_attname_column()[1])

        for k in self.static_mapping.keys():
            model_fields.append('"%s"' % k)

        for field, value in self.overloaded_crosswalk:
            model_fields.append('"%s"' % field.get_attname_column()[1])

        options['model_fields'] = ", ".join(model_fields)

        temp_fields = []
        for field, header in self.field_header_crosswalk:
            if header in self.ignore_headers:
                continue
            temp_fields.append(self._generate_insert_temp_fields(
                field, header)
            )

        for v in self.static_mapping.values():
            temp_fields.append("'%s'" % v)

        for field, value in self.overloaded_crosswalk:
            temp_fields.append(self._generate_insert_temp_fields(
                field, value)
            )
        options['temp_fields'] = ", ".join(temp_fields)

        return sql % options

    def _generate_insert_temp_fields(self, concrete, column):
        string = '"%s"' % column
        if hasattr(concrete, 'copy_template'):
            string = concrete.copy_template % dict(name=column)
        template_method = 'copy_%s_template' % concrete.name
        if hasattr(self.model, template_method):
            template = getattr(self.model(), template_method)()
            string = template % dict(name=column)
        return string
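
For orientation, here is a standalone sketch (not the library's code) of the column lists prep_insert() assembles for the OverloadMockObject test further down, assuming CSV headers NAME, NUMBER and DATE, the copy templates defined in tests/models.py below, a hypothetical static_mapping, and no copy_template attribute on the other fields; the enclosing SQL template sits outside this hunk, so its exact shape is also an assumption:

# Illustrative pairing of INSERT target columns with SELECT expressions.
model_fields = [
    '"name"', '"num"', '"dt"',  # mapped CSV columns ("num" comes from db_column)
    '"source"',                 # hypothetical static_mapping key
    '"lower_name"',             # overloaded_mapping target field
]
temp_fields = [
    'upper("NAME")',            # copy_name_template applied to the NAME header
    '"NUMBER"', '"DATE"',       # plain quoted headers (no copy templates assumed)
    "'csv-import'",             # hypothetical static_mapping value, quoted as a literal
    'lower("NAME")',            # copy_lower_name_template reuses the NAME header
]

# A plausible shape for the statement these options feed into (assumed, not shown in this hunk):
print('INSERT INTO "tests_overloadmockobject" (%s) (SELECT %s FROM "temp_tests_overloadmockobject");'
      % (', '.join(model_fields), ', '.join(temp_fields)))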
32 changes: 32 additions & 0 deletions tests/models.py
@@ -29,3 +29,35 @@ class Meta:
    def copy_name_template(self):
        return 'upper("%(name)s")'
    copy_name_template.copy_type = 'text'


class LimitedMockObject(models.Model):
    name = models.CharField(max_length=500)
    dt = models.DateField(null=True)

    class Meta:
        app_label = 'tests'

    def copy_name_template(self):
        return 'upper("%(name)s")'
    copy_name_template.copy_type = 'text'


class OverloadMockObject(models.Model):
    name = models.CharField(max_length=500)
    lower_name = models.CharField(max_length=500)
    number = MyIntegerField(null=True, db_column='num')
    dt = models.DateField(null=True)
    parent = models.ForeignKey('MockObject', null=True, default=None)

    class Meta:
        app_label = 'tests'

    def copy_name_template(self):
        return 'upper("%(name)s")'
    copy_name_template.copy_type = 'text'

    def copy_lower_name_template(self):
        return 'lower("%(name)s")'
    copy_lower_name_template.copy_type = 'text'

57 changes: 56 additions & 1 deletion tests/tests.py
@@ -1,6 +1,7 @@
import os
from datetime import date
from .models import MockObject, ExtendedMockObject, LimitedMockObject,\
    OverloadMockObject
from postgres_copy import CopyMapping
from django.test import TestCase

@@ -18,6 +19,8 @@ def setUp(self):
    def tearDown(self):
        MockObject.objects.all().delete()
        ExtendedMockObject.objects.all().delete()
        LimitedMockObject.objects.all().delete()
        OverloadMockObject.objects.all().delete()

    def test_bad_call(self):
        with self.assertRaises(TypeError):
@@ -57,6 +60,17 @@ def test_bad_field(self):
                dict(name1='NAME', number='NUMBER', dt='DATE'),
            )

    def test_limited_fields(self):
        try:
            CopyMapping(
                LimitedMockObject,
                self.name_path,
                dict(name='NAME', dt='DATE'),
                ignore_headers=['NUMBER']
            )
        except ValueError:
            self.fail("Failed trying to ignore fields")

    def test_simple_save(self):
        c = CopyMapping(
            MockObject,
@@ -71,6 +85,20 @@ def test_simple_save(self):
            date(2012, 1, 1)
        )

    def test_limited_save(self):
        c = CopyMapping(
            LimitedMockObject,
            self.name_path,
            dict(name='NAME', dt='DATE'),
            ignore_headers=['NUMBER']
        )
        c.save()
        self.assertEqual(LimitedMockObject.objects.count(), 3)
        self.assertEqual(
            LimitedMockObject.objects.get(name='BEN').dt,
            date(2012, 1, 1)
        )

    def test_save_foreign_key(self):
        c = CopyMapping(
            MockObject,
@@ -217,3 +245,30 @@ def test_save_foreign_key(self):
            MockObject.objects.get(name='BEN').dt,
            date(2012, 1, 1)
        )

    def test_overload_save(self):
        c = CopyMapping(
            OverloadMockObject,
            self.name_path,
            dict(name='NAME', number='NUMBER', dt='DATE'),
            overloaded_mapping=dict(lower_name='NAME')
        )
        c.save()
        self.assertEqual(OverloadMockObject.objects.count(), 3)
        self.assertEqual(OverloadMockObject.objects.get(name='BEN').number, 1)
        self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1)
        self.assertEqual(
            OverloadMockObject.objects.get(name='BEN').dt,
            date(2012, 1, 1)
        )
        omo = OverloadMockObject.objects.first()
        self.assertEqual(omo.name.lower(), omo.lower_name)

    def test_missing_overload_field(self):
        with self.assertRaises(ValueError):
            c = CopyMapping(
                OverloadMockObject,
                self.name_path,
                dict(name='NAME', number='NUMBER', dt='DATE'),
                overloaded_mapping=dict(missing='NAME')
            )
1 change: 1 addition & 0 deletions tox.ini
@@ -9,6 +9,7 @@ deps =
    pep8
    pyflakes
    coverage
    psycopg2
commands =
    pep8 postgres_copy
    pyflakes postgres_copy