-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsql_test_base.py
173 lines (139 loc) · 5.58 KB
/
sql_test_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from os.path import dirname, join
from numpy.testing import assert_almost_equal
from sqlalchemy import asc
from snowexsql.db import get_db, initialize
def pytest_generate_tests(metafunc):
    """
    pytest hook that parametrizes class-based tests from a ``params`` table.

    If the test's class has a ``params`` attribute whose keys include the
    test function's name, the test is parametrized once per dict in that
    entry. Argument names are the keys of the first dict, in sorted order,
    and each dict supplies the values for one run. Any other test is
    collected normally.
    """
    test_cls = metafunc.cls
    # Module-level tests (cls is None) and classes without params are left alone
    if not hasattr(test_cls, 'params'):
        return

    name = metafunc.function.__name__
    table = test_cls.params
    if name not in table.keys():
        return

    cases = table[name]
    names = sorted(cases[0])
    values = [[case[arg] for arg in names] for case in cases]
    metafunc.parametrize(names, values)
class DBSetup:
    """
    Base class for all our tests. Ensures that we clean up after every class that's run.

    ``setup_class`` connects to the ``localhost/test`` database and calls
    ``initialize`` on its engine, ``teardown_class`` drops every table known
    to the metadata, and after each test the session is rolled back.
    """

    @classmethod
    def setup_class(cls):
        """
        Set up the database one time for testing.

        Stores ``db``, ``data_dir``, ``engine``, ``session`` and ``metadata``
        on the class, so every test in the class shares the same session.
        """
        cls.db = 'localhost/test'
        cls.data_dir = join(dirname(__file__), 'data')
        creds = join(dirname(__file__), 'credentials.json')
        cls.engine, cls.session, cls.metadata = get_db(cls.db, credentials=creds, return_metadata=True)
        initialize(cls.engine)

    @classmethod
    def teardown_class(cls):
        """
        Remove the database tables created for this class.

        The session is closed first so it releases its connection and any open
        transaction, which could otherwise hold locks that block the DROP
        statements issued through the engine. The drop still runs if closing
        the session fails.
        """
        try:
            cls.session.close()
        finally:
            cls.metadata.drop_all(bind=cls.engine)

    def teardown(self):
        """
        Discard whatever the finished test did through the shared session.

        The rollback always runs, even when the flush raises, so a failing
        test cannot leave pending state behind for the next one.
        """
        try:
            self.session.flush()
        finally:
            self.session.rollback()

    def teardown_method(self, method=None):
        """
        pytest per-test teardown hook; delegates to ``teardown``.

        pytest >= 8 no longer invokes nose-style ``teardown`` methods, so this
        hook is what actually triggers the rollback. Delegating keeps any
        subclass override of ``teardown`` in effect.
        """
        self.teardown()
class TableTestBase(DBSetup):
    """
    Test any table by picking an uploader, a table class and expectations.

    Subclasses configure which uploader to run (``UploaderClass`` with
    ``args``/``kwargs``), which ORM class to query (``TableClass``) and which
    expectations to check (``params``). ``setup_class`` uploads the data once
    for the class; the ``test_*`` methods then query it.
    """
    # Class to use to upload the data
    UploaderClass = None

    # Positional arguments to pass to the uploader class. The first one is a
    # file name (or a list of file names for batch uploaders) relative to the
    # test data directory.
    args = []

    # Keyword args to pass to the uploader class
    kwargs = {}

    # Always define this using a table class from data.py and is used for ORM
    TableClass = None

    # First filter to be applied is count_attribute == data_name
    count_attribute = 'type'

    # Define params which is a dictionary of test names and their args
    params = {
        'test_count': [dict(data_name=None, expected_count=None)],
        'test_value': [
            dict(data_name=None, attribute_to_check=None, filter_attribute=None, filter_value=None, expected=None)],
        'test_unique_count': [dict(data_name=None, attribute_to_count=None, expected_count=None)]
    }

    @classmethod
    def setup_class(cls):
        """
        Set up the database and upload this class's data once.

        File names in ``args[0]`` (and ``kwargs['smp_log_f']`` when it is not
        None) are resolved against the test data directory, and the database
        name and credentials are added to the uploader kwargs. The resolved
        copies are stored on ``cls`` instead of mutating the inherited
        ``args``/``kwargs`` in place, so a subclass that does not override them
        cannot leak changes into the shared base-class defaults.
        """
        super().setup_class()

        args = list(cls.args)
        kwargs = dict(cls.kwargs)

        if args:
            if isinstance(args[0], (list, tuple)):
                # Batches always provide a list of files
                args[0] = [join(cls.data_dir, f) for f in args[0]]
            else:
                # Single uploaders only upload a single file
                args[0] = join(cls.data_dir, args[0])

        # In case we have a smp_log file make it point to the data folder too
        if kwargs.get('smp_log_f') is not None:
            kwargs['smp_log_f'] = join(cls.data_dir, kwargs['smp_log_f'])

        kwargs['db_name'] = cls.db
        kwargs['credentials'] = join(dirname(__file__), 'credentials.json')

        # Keep the resolved values visible on the class, as before
        cls.args = args
        cls.kwargs = kwargs

        u = cls.UploaderClass(*args, **kwargs)

        # Allow for batches and single upload
        if 'batch' in cls.UploaderClass.__name__.lower():
            u.push()
        else:
            u.submit(cls.session)

    def get_query(self, filter_attribute, filter_value, query=None):
        """
        Return a query filtered to ``filter_attribute == filter_value``.

        Args:
            filter_attribute: Name of the TableClass attribute to filter on
            filter_value: Value that attribute must equal
            query: Existing query to extend; when None a new query on
                TableClass is started

        Return:
            q: Uncompiled SQLAlchemy Query object
        """
        if query is None:
            query = self.session.query(self.TableClass)

        fa = getattr(self.TableClass, filter_attribute)
        # NOTE(review): ordering by the column that was just equality-filtered
        # does not pin down row order, so callers that take the "first" record
        # get whatever order the database returns; consider ordering by a
        # primary key.
        q = query.filter(fa == filter_value).order_by(asc(fa))
        return q

    def test_count(self, data_name, expected_count):
        """
        Test the record count of a data type.
        """
        q = self.get_query(self.count_attribute, data_name)
        records = q.all()
        assert len(records) == expected_count

    def test_value(self, data_name, attribute_to_check, filter_attribute, filter_value, expected):
        """
        Test that the first value in a filtered record search is as expected.

        Records are filtered by ``count_attribute == data_name`` and then by
        ``filter_attribute == filter_value``. ``attribute_to_check`` of the
        first record (None when nothing matches) is compared with
        ``expected``: to 6 decimals when it converts to a float, exactly
        otherwise.
        """
        # Filter to the data type we're querying
        q = self.get_query(self.count_attribute, data_name)

        # Add another filter by some attribute
        q = self.get_query(filter_attribute, filter_value, query=q)
        records = q.all()

        received = getattr(records[0], attribute_to_check) if records else None

        # Values float() accepts are compared numerically; anything it rejects
        # (e.g. None or non-numeric strings) is compared as-is. Only conversion
        # errors are caught so interrupts and real bugs still surface.
        try:
            received = float(received)
        except (TypeError, ValueError, OverflowError):
            pass

        if isinstance(received, float):
            assert_almost_equal(received, expected, 6)
        else:
            assert received == expected

    def test_unique_count(self, data_name, attribute_to_count, expected_count):
        """
        Test that the number of unique values in a given attribute is as expected.
        """
        # Restrict to the data type under test
        q = self.get_query(self.count_attribute, data_name)
        records = q.all()
        received = len({getattr(r, attribute_to_count) for r in records})
        assert received == expected_count