-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvalidate_match_csv.py
52 lines (37 loc) · 1.51 KB
/
validate_match_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Validate the Citation Manifest CSV against its own match examples.

Every row's ``pattern`` column is loaded into a single spaCy entity ruler.
Then each row's ``match_example`` is run through that pipeline. The example
must be matched exactly once, as the full matched text, and tagged with that
row's ``id``. Any violation raises ``RuntimeError``.
"""

import csv
import json
from pathlib import Path

from spacy.lang.en import English

# Relative path: resolved against the current working directory, so the script
# presumably has to be run from the repository root.
CSV_PATH = Path("src/caselaw_extraction/rules/2022_06_30_Citation_Manifest.csv")
def setup_nlp(patterns):
    """Build a blank English pipeline whose only component is an entity ruler.

    Args:
        patterns: EntityRuler pattern definitions, added in the given order.

    Returns:
        The configured spaCy ``English`` pipeline.
    """
    pipeline = English()
    # add_pipe returns the newly created component, so configure it inline.
    pipeline.add_pipe("entity_ruler").add_patterns(patterns)
    return pipeline
def run_nlp(nlp, text):
    """Run *nlp* over *text* and list the entities it recognised.

    Args:
        nlp: A callable pipeline returning an object with an ``ents`` sequence.
        text: The input string to process.

    Returns:
        A list of ``(entity_text, entity_id)`` tuples in document order.
    """
    recognised = nlp(text).ents
    return [(span.text, span.id_) for span in recognised]
def csv_as_dict(csv_path):
with open(csv_path) as csv_file:
csvreader = csv.reader(csv_file)
headers = next(csvreader)
return [dict(zip(headers, row, strict=False)) for row in csvreader]
def get_patterns(csv_dict):
    """Decode the JSON ``pattern`` column of every manifest row.

    All raw pattern strings are gathered first and decoded afterwards, so a
    row missing the column is reported before any JSON is parsed.

    Args:
        csv_dict: Iterable of row dicts, each with a ``"pattern"`` key holding
            a JSON document.

    Returns:
        The decoded patterns, in row order.

    Raises:
        KeyError: If a row has no ``"pattern"`` key.
        json.JSONDecodeError: If a pattern string is not valid JSON.
    """
    raw_patterns = [row["pattern"] for row in csv_dict]
    return list(map(json.loads, raw_patterns))
csv_dict = csv_as_dict(CSV_PATH)
patterns = get_patterns(csv_dict)
nlp = setup_nlp(patterns)

# NOTE(review): the result of this call is discarded; it only confirms the
# pipeline runs on free text before the per-row checks. Confirm it is still
# wanted.
run_nlp(nlp, "this is [2023] UKSC 3 you know")

# Each row's example must be found, in full, exactly once, with the row's id.
for item in csv_dict:
    example = item["match_example"]
    # Embed the example between unrelated words so it has to be matched
    # inside running text rather than as the entire input.
    match = run_nlp(nlp, f"jam {example} cake")
    if not match:
        # Without this guard, match[0] would raise a bare IndexError.
        msg = f"No match for {example!r}"
        raise RuntimeError(msg)
    matched_text, matched_id = match[0]
    if matched_text != example:
        msg = f"Matched {matched_text!r} which isn't {example!r}"
        raise RuntimeError(msg)
    if matched_id != item["id"]:
        msg = f"Matched ID was {matched_id!r} which isn't {item['id']!r}"
        raise RuntimeError(msg)
    if len(match) > 1:
        msg = f"{len(match)} matches for {example!r}"
        raise RuntimeError(msg)