Skip to content

Commit 69703e9

Browse files
committed
bug fixes and enhancements
1 parent 5ea8f58 commit 69703e9

File tree

9 files changed

+107
-68
lines changed

9 files changed

+107
-68
lines changed

redistricting/gui/rdsdockwidget.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def plan(self, value: RdsPlan):
6060
self._plan.nameChanged.connect(self.planNameChanged)
6161
self.lblPlanName.setText(self._plan.name)
6262

63-
def planNameChanged(self, name):
64-
self.lblPlanName.setText(name)
63+
def planNameChanged(self):
64+
if self.sender() == self.plan:
65+
self.lblPlanName.setText(self.plan.name)
6566

6667
def btnHelpClicked(self):
6768
showHelp(self.helpContext)

redistricting/gui/wzpeditplangeofields.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ def fieldChanged(self, field):
151151
not self.cmbAddlGeoField.isExpression() or self.cmbAddlGeoField.isValidExpression()))
152152

153153
def addField(self):
154-
field, isExpression, isValid = self.cmbAddlGeoField.currentField()
154+
field, _, isValid = self.cmbAddlGeoField.currentField()
155155
if not isValid:
156156
return
157157

158158
layer = self.field('sourceLayer')
159-
self.fieldsModel.appendField(layer, field, isExpression)
159+
self.fieldsModel.appendField(layer, field)

redistricting/models/field.py

+3
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,9 @@ def getName(self, feature_or_key: Union[QgsFeature, str]) -> str:
279279

280280
def makeJoin(self):
281281
rel = self.getRelation()
282+
if rel is None:
283+
return None
284+
282285
pair = rel.fieldPairs()
283286
if len(pair) > 1:
284287
return None

redistricting/services/planbuilder.py

+2
Original file line numberDiff line numberDiff line change
@@ -215,5 +215,7 @@ def createPlan(self, createLayers=True, planParent: Optional[QObject] = None) ->
215215

216216
if createLayers:
217217
self.createLayers(plan)
218+
elif self._geoPackagePath and self._geoPackagePath.exists():
219+
plan.addLayersFromGeoPackage(self._geoPackagePath)
218220

219221
return plan

redistricting/services/planeditor.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from typing import (
2626
Iterable,
2727
Optional,
28-
Set,
2928
overload
3029
)
3130

@@ -70,7 +69,7 @@ def __init__(self, parent: QObject = None, planUpdater: DistrictUpdater = None):
7069
self._modifiedFields = set()
7170

7271
@property
73-
def modifiedFields(self) -> Set[str]:
72+
def modifiedFields(self) -> set[str]:
7473
return self._modifiedFields
7574

7675
def setProgress(self, progress: float):
@@ -150,7 +149,7 @@ def _addFieldToLayer(self, layer: QgsVectorLayer, fieldOrFieldName, fieldType=No
150149
def _updatePopFields(self):
151150
if self._plan.distLayer:
152151
layer = self._plan.distLayer
153-
addedFields: Set[RdsField] = set(self._popFields) - set(self._plan.popFields)
152+
addedFields: list[RdsField] = [f for f in self._popFields if f not in self._plan.popFields]
154153
if addedFields:
155154
self._addFieldToLayer(layer, [f.makeQgsField() for f in addedFields])
156155

@@ -160,11 +159,11 @@ def _updateDataFields(self):
160159
if self._plan.distLayer:
161160
layer = self._plan.distLayer
162161

163-
addedFields: Set[RdsDataField] = set(self._dataFields) - set(self._plan.dataFields)
162+
addedFields: list[RdsDataField] = [f for f in self._dataFields if f not in self._plan.dataFields]
164163
if addedFields:
165164
self._addFieldToLayer(layer, [f.makeQgsField() for f in addedFields])
166165

167-
removedFields: Set[RdsDataField] = set(self._plan.dataFields) - set(self._dataFields)
166+
removedFields: list[RdsDataField] = [f for f in self._plan.dataFields if f not in self._dataFields]
168167
if removedFields:
169168
provider = layer.dataProvider()
170169
for f in removedFields:
@@ -177,7 +176,7 @@ def _updateDataFields(self):
177176

178177
def _updateGeoFields(self):
179178
def removeFields():
180-
removedFields: Set[RdsField] = set(self._plan.geoFields) - set(self._geoFields)
179+
removedFields: list[RdsField] = [f for f in self._plan.geoFields if f not in self._geoFields]
181180
if removedFields:
182181
provider = self._assignLayer.dataProvider()
183182
fields = self._assignLayer.fields()
@@ -217,7 +216,7 @@ def completed():
217216

218217
saveFields = self._plan.geoFields
219218
layer = self._plan.assignLayer
220-
addedFields: Set[RdsField] = set(self._geoFields) - set(self._plan.geoFields)
219+
addedFields: list[RdsField] = [f for f in self._geoFields if f not in self._plan.geoFields]
221220
if addedFields:
222221
self._addFieldToLayer(layer, [f.makeQgsField() for f in addedFields])
223222

redistricting/services/planlistmodel.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def planListUpdated(self):
136136
self.beginResetModel()
137137
self.endResetModel()
138138

139-
def updatePlan(self, plan):
139+
def updatePlan(self):
140+
plan = self.sender()
140141
idx1 = self.indexFromPlan(plan)
141142
idx2 = self.createIndex(idx1.row(), self.columnCount() - 1)
142143
self.dataChanged.emit(idx1, idx2)

redistricting/services/tasks/updatedistricts.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,17 @@ def calcCutEdges(self) -> Union[int, None]:
130130
# select count of unit pairs where
131131
# 1) assigned districts are different (also takes care of excluding unassigned units from count),
132132
# 2) combination is unique (count a,b but not b,a),
133-
# 3) bounding boxes overlap (using spatial index to minimize more intensive adjacency checks)
133+
# 3) bounding boxes touch or overlap (using spatial index to minimize more intensive adjacency checks)
134134
# 4) units are adjacent at more than a point
135-
sql = "SELECT count(*) " \
136-
"FROM assignments a JOIN assignments b " \
137-
f"ON b.{self.distField} != a.{self.distField} AND b.{self.geoIdField} > a.{self.geoIdField} " \
138-
"AND b.fid IN(SELECT id FROM rtree_assignments_geometry r " \
139-
"WHERE r.minx < st_maxx(a.geometry) and r.maxx >= st_minx(a.geometry) " \
140-
"AND r.miny < st_maxy(a.geometry) and r.maxy >= st_miny(a.geometry)) " \
141-
"AND st_length(st_intersection(a.geometry, b.geometry)) > 0 "
135+
sql = f""""SELECT count(*)
136+
FROM assignments a JOIN assignments b
137+
ON b.{self.distField} != a.{self.distField} AND b.{self.geoIdField} > a.{self.geoIdField}
138+
AND b.fid IN (
139+
SELECT id FROM rtree_assignments_geometry r
140+
WHERE r.minx <= st_maxx(a.geometry) and r.maxx >= st_minx(a.geometry)
141+
AND r.miny <= st_maxy(a.geometry) and r.maxy >= st_miny(a.geometry)
142+
)
143+
AND st_relate(a.geometry, b.geometry, 'F***1****')"""
142144

143145
c = db.execute(sql)
144146
return c.fetchone()[0]

redistricting/utils/layer.py

+61-48
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
QgsFeedback,
4646
QgsVectorLayer
4747
)
48-
from shapely import wkb
4948

5049
from .. import CanceledError
5150
from .intl import tr
@@ -83,6 +82,17 @@ def iterateWithProgress(self, it: Iterator, total: int = 0):
8382
self.checkCanceled()
8483
yield n
8584

85+
def split_provider_url(self):
86+
uri_parts = self._layer.dataProvider().dataSourceUri().split('|')
87+
if len(uri_parts) <= 1:
88+
raise ValueError("Could not determine table name from URI")
89+
database = uri_parts[0]
90+
lexer = shlex.shlex(uri_parts[1])
91+
lexer.whitespace_split = True
92+
lexer.whitespace = '&'
93+
params = dict(pair.split('=', 1) for pair in lexer)
94+
return database, params
95+
8696
def read_qgis(
8797
self,
8898
columns: Optional[list[str]] = None,
@@ -92,58 +102,64 @@ def read_qgis(
92102
filt: Optional[dict[str, Any]] = None
93103
) -> Union[pd.DataFrame, gpd.GeoDataFrame]:
94104
def prog_attributes(f: QgsFeature):
95-
nonlocal count
96-
count += 1
97-
if count % chunksize == 0:
98-
self.checkCanceled()
99-
self.updateProgress(fc, count)
100105
attrs = [f.attribute(i) for i in indices]
101106
if read_geometry:
102107
attrs.append(f.geometry().asWkb().data())
103108
return attrs
104109

105110
if not chunksize:
106111
chunksize = 1
107-
fc = self._layer.featureCount()
108-
count = 0
112+
109113
fields = self._layer.fields()
114+
req = QgsFeatureRequest()
115+
if self._feedback:
116+
req.setFeedback(self._feedback)
117+
110118
if columns is None:
111119
columns = fields.names()
112-
gen = (prog_attributes(f) for f in self._layer.getFeatures())
120+
indices = range(len(columns))
113121
else:
114122
indices = [fields.lookupField(c) for c in columns]
115123
if any((i == -1 for i in indices)):
116124
raise RuntimeError("Bad fields")
117-
req = QgsFeatureRequest()
118125
req.setSubsetOfAttributes(indices)
119-
if filt:
120-
filt = {f: f"{f'{v}' if isinstance(v, str) else v}" for f, v in filt.items()}
121-
expr = f"{'AND'.join(f'{f} = {v}' for f,v in filt.items())}"
122-
req.setFilterExpression(expr)
123-
gen = (prog_attributes(f) for f in self._layer.getFeatures(req))
126+
127+
if filt:
128+
expr = f"{' AND '.join(f'({f} = {v!r})' for f, v in filt.items())}"
129+
req.setFilterExpression(expr)
130+
131+
if order:
132+
clause = QgsFeatureRequest.OrderByClause(order)
133+
orderby = QgsFeatureRequest.OrderBy([clause])
134+
req.setOrderBy(orderby)
124135

125136
if read_geometry:
126-
columns = [*columns, "geometry"]
127-
df = pd.DataFrame(gen, columns=columns)
128-
df['geometry'] = df['geometry'].apply(wkb.loads)
129-
df = gpd.GeoDataFrame(df, geometry="geometry", crs=self._layer.crs().authid())
137+
columns.append('geometry')
138+
df = gpd.GeoDataFrame.from_features(self._layer.getFeatures(req), self._layer.crs().authid(), columns)
130139
else:
140+
gen = (prog_attributes(f) for f in self._layer.getFeatures(req))
131141
df = pd.DataFrame(gen, columns=columns)
132142

133-
if order and order in df.columns:
134-
df = df.sort_values(order).set_index(order)
135-
136143
return df
137144

138145
def gpd_read(
139146
self,
140-
source,
147+
source=None,
141148
fc: int = 0,
142149
chunksize: Optional[int] = None,
143150
filt: Optional[dict[str, Any]] = None,
144151
**kwargs
145152
) -> gpd.GeoDataFrame:
146153
df: gpd.GeoDataFrame = None
154+
155+
if source is None:
156+
source, params = self.split_provider_url()
157+
if "layer" not in kwargs:
158+
kwargs["layer"] = params["layername"]
159+
160+
if filt is not None:
161+
kwargs["where"] = " AND ".join(f"({f} = {v!r})" for f, v in filt.items())
162+
147163
if (fc or chunksize):
148164
if chunksize is None:
149165
divisions = 10
@@ -166,10 +182,6 @@ def gpd_read(
166182
)
167183
else:
168184
df = gpd.read_file(source, **kwargs)
169-
170-
if filt:
171-
for f, v in filt.items():
172-
df = df[df[f] == v]
173185
else:
174186
df = gpd.read_file(source, **kwargs)
175187
self.updateProgress(len(df), len(df))
@@ -184,7 +196,8 @@ def read_layer(
184196
order: Optional[str] = ...,
185197
filt: Optional[dict[str, Any]] = ...,
186198
read_geometry: Literal[False] = ...,
187-
chunksize: int = ...
199+
chunksize: int = ...,
200+
**kwargs
188201
) -> pd.DataFrame:
189202
...
190203

@@ -195,7 +208,8 @@ def read_layer(
195208
order: Optional[str] = ...,
196209
filt: Optional[dict[str, Any]] = ...,
197210
read_geometry: Literal[True] = ...,
198-
chunksize: int = ...
211+
chunksize: int = ...,
212+
**kwargs
199213
) -> gpd.GeoDataFrame:
200214
...
201215

@@ -205,7 +219,8 @@ def read_layer(
205219
order=None,
206220
filt=None,
207221
read_geometry=True,
208-
chunksize=0
222+
chunksize=0,
223+
**kwargs
209224
) -> Union[pd.DataFrame, gpd.GeoDataFrame]:
210225
def makeSqlQuery():
211226
nonlocal filt
@@ -216,7 +231,7 @@ def makeSqlQuery():
216231
cols = ",".join(columns)
217232
if read_geometry and (g := self.getGeometryColumn(self._layer)):
218233
cols += f",{g}"
219-
sql = f"SELECT {cols} from {self.getTableName(self._layer)}"
234+
sql = f"SELECT {cols} FROM {self.getTableName(self._layer)}"
220235
if filt or self._layer.subsetString():
221236
filters = []
222237
if filt:
@@ -248,21 +263,15 @@ def makeSqlQuery():
248263

249264
if self._layer.storageType() in ("GPKG", "OpenFileGDB"):
250265
if read_geometry:
251-
uri_parts = self._layer.dataProvider().dataSourceUri().split('|')
252-
if len(uri_parts) <= 1:
253-
raise ValueError("Could not determine table name from URI")
254-
database = uri_parts[0]
255-
lexer = shlex.shlex(uri_parts[1])
256-
lexer.whitespace_split = True
257-
lexer.whitespace = '&'
258-
params = dict(pair.split('=', 1) for pair in lexer)
259-
df = self.gpd_read(database, self._layer.featureCount(), chunksize,
260-
layer=params['layername'], columns=columns)
266+
database, params = self.split_provider_url()
267+
df = self.gpd_read(database, self._layer.featureCount(), chunksize, filt,
268+
layer=params['layername'], columns=columns, **kwargs)
261269
if order:
262270
df = df.set_index(order).sort_index()
263271
else:
264272
with self._connectSqlOgrSqlite(self._layer.dataProvider()) as db:
265-
df = pd.read_sql(makeSqlQuery(), db, index_col=order, columns=columns, chunksize=chunksize)
273+
df = pd.read_sql(makeSqlQuery(), db, index_col=order,
274+
columns=columns, chunksize=chunksize, **kwargs)
266275
if isinstance(df, Iterator):
267276
df = pd.concat(self.iterateWithProgress(df, total))
268277
elif self._layer.dataProvider().name() in ('spatialite', 'SQLite'):
@@ -272,7 +281,7 @@ def makeSqlQuery():
272281
shlex.split(re.sub(r' \(\w+\)', '', self._layer.dataProvider().dataSourceUri(True)))
273282
)
274283
df = self.gpd_read(params['dbname'], self._layer.featureCount(),
275-
chunksize, layer=params['table'], columns=columns)
284+
chunksize, layer=params['table'], columns=columns, **kwargs)
276285
if order:
277286
df = df.set_index(order).sort_index()
278287
else:
@@ -288,15 +297,17 @@ def makeSqlQuery():
288297
db,
289298
self.getGeometryColumn(self._layer),
290299
index_col=order,
291-
chunksize=chunksize
300+
chunksize=chunksize,
301+
**kwargs
292302
)
293303
else:
294304
df = pd.read_sql(
295305
makeSqlQuery(),
296306
db,
297307
index_col=order,
298308
columns=columns,
299-
chunksize=chunksize
309+
chunksize=chunksize,
310+
**kwargs
300311
)
301312

302313
if isinstance(df, Iterator):
@@ -307,7 +318,8 @@ def makeSqlQuery():
307318
self._layer.featureCount(),
308319
columns=columns,
309320
chunksize=chunksize,
310-
read_geometry=read_geometry
321+
read_geometry=read_geometry,
322+
**kwargs
311323
)
312324
if order:
313325
df = df.set_index(order).sort_index()
@@ -333,11 +345,12 @@ def makeSqlQuery():
333345
delimiter=delimiter,
334346
header=header,
335347
usecols=usecols,
336-
chunksize=chunksize
348+
chunksize=chunksize,
349+
**kwargs
337350
)
338351
df = pd.concat(self.iterateWithProgress(reader.get_chunk(), total))
339352
else:
340-
df = pd.read_csv(uri_parts.path, delimiter=delimiter, header=header, usecols=usecols)
353+
df = pd.read_csv(uri_parts.path, delimiter=delimiter, header=header, usecols=usecols, **kwargs)
341354
if header is None:
342355
if len(columns) == len(df.columns):
343356
df.columns = columns
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import functools
2+
import timeit
3+
4+
import redistricting
5+
6+
7+
# pylint: disable=import-outside-toplevel
8+
class TestLayerReader:
9+
def test_readqgis(self, block_layer):
10+
reader = redistricting.utils.layer.LayerReader(block_layer)
11+
12+
t1 = timeit.timeit(reader.read_qgis, number=20)
13+
print(t1)
14+
15+
def test_gpd_read(self, block_layer):
16+
reader = redistricting.utils.layer.LayerReader(block_layer)
17+
t2 = timeit.timeit(functools.partial(reader.gpd_read, chunksize=-1), number=20)
18+
print(t2)

0 commit comments

Comments
 (0)