-
Notifications
You must be signed in to change notification settings - Fork 16
/
utils.py
executable file
·105 lines (87 loc) · 2.99 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import csv
import functools
import hashlib
import itertools
import json
import os
import random
import re
import sys
from typing import Any, Dict, List, Set, Tuple
import unicodedata
# Field separator used on a single line of the shared-task file format,
# e.g. "<prompt_id>|<prompt text>" on a group's first line, or
# "<response>|<weight>" on response lines when weights are present.
FIELDSEP = "|"
def makeID(text: str) -> str:
    """
    Derive a deterministic ID string from ``text``.

    The text is lower-cased, UTF-8 encoded and hashed with MD5; the hex
    digest is returned with a ``"prompt_"`` prefix.

    WARNING: This is typically used to create prompt IDs. The hash covers
    every character, including stray leading, trailing or doubled spaces, so
    two prompts that look the same but differ in whitespace produce
    different IDs and may not match the ID you are expecting.
    """
    lowered = text.lower()
    digest = hashlib.md5(lowered.encode("utf-8")).hexdigest()
    return "prompt_" + digest
def read_trans_prompts(lines: List[str]) -> List[Tuple[str,str]]:
    """
    Read the prompts from lines in the shared-task format.

    The input is a sequence of groups separated by blank lines. The first
    line of each group has the form ``<id>|<prompt text>``; all following
    lines of the group belong to that prompt's set and are ignored here.
    Every line is stripped of surrounding whitespace and lower-cased before
    it is parsed.

    Args:
        lines: raw lines of the file (trailing newlines are fine).

    Returns:
        A list of ``(id, prompt)`` tuples in file order, one per group.

    Raises:
        ValueError: if the first line of a group contains no field separator.
    """
    ids_prompts: List[Tuple[str, str]] = []
    at_group_start = True
    for raw in lines:
        line = raw.strip().lower()
        if not line:
            # A blank line ends the current group; the next non-blank line is
            # the KEY line of a new group.
            at_group_start = True
            continue
        if at_group_start:
            # Split only on the first separator so prompt text that itself
            # contains FIELDSEP is kept intact instead of raising
            # "too many values to unpack". A line with no separator still
            # raises ValueError, as before.
            key, prompt = line.split(FIELDSEP, 1)
            ids_prompts.append((key, prompt))
            at_group_start = False
    return ids_prompts
def strip_punctuation(text: str) -> str:
    """
    Return ``text`` with every Unicode punctuation character removed.

    A character counts as punctuation when its Unicode general category
    begins with "P" (connector, dash, open/close, initial/final quote and
    other punctuation). This covers ASCII marks as well as CJK ones such as
    "。" and "「」". Symbols (e.g. "$", "+") and whitespace are kept.
    """
    # unicodedata.category always yields a two-letter code, so checking the
    # first letter is the same as checking the "P" prefix.
    kept = [ch for ch in text if unicodedata.category(ch)[0] != "P"]
    return "".join(kept)
def read_transfile(lines: List[str], strip_punc=True, weighted=False) -> Dict[str, Dict[str, float]]:
    """
    Parse lines in the shared-task format into ``{prompt_id: {response: weight}}``.

    The input is a sequence of groups separated by blank lines. The first
    line of a group is the KEY line ``<prompt_id>|<prompt text>``; only the
    ID is kept. Every following line of the group is one response. Each line
    is stripped of surrounding whitespace and lower-cased before parsing.

    Args:
        lines: raw lines of the file (trailing newlines are fine).
        strip_punc: if true, Unicode punctuation is removed from each
            response (via ``strip_punctuation``) before it is used as a key.
        weighted: if true, each response line must be ``<text>|<weight>``
            and the weight is parsed as a float; otherwise every response
            gets weight 1.0.

    Returns:
        A dict mapping each prompt ID to a dict of response text -> weight.
        Groups with an empty ID or with no responses are skipped. If a
        prompt ID appears in more than one group, a warning is printed and
        the later group replaces the earlier one. If a response text repeats
        within a group, the last weight seen is kept.

    Raises:
        ValueError: if a KEY line has no separator, a weighted response line
            has no separator, or a weight is not a valid float.
    """
    data: Dict[str, Dict[str, float]] = {}
    options: Dict[str, float] = {}
    key = ""
    first = True

    def _store_group(group_key: str, group_options: Dict[str, float]) -> None:
        # Commit one finished group. Shared by the blank-line path and the
        # end-of-input path so both apply the same checks and warning.
        if len(group_key) > 0 and len(group_options) > 0:
            if group_key in data:
                print(f"Warning: duplicate sentence! {group_key}")
            data[group_key] = group_options

    for raw in lines:
        line = raw.strip().lower()
        if len(line) == 0:
            # A blank line closes the current group.
            _store_group(key, options)
            # Always start a fresh response set, even if the previous group
            # was skipped, so its responses cannot leak into the next prompt.
            options = {}
            first = True
            continue

        if first:
            # In a group, the first line is the KEY. Split once so a prompt
            # text containing the separator does not break unpacking.
            key, _ = line.split(FIELDSEP, 1)
            first = False
            continue

        # All other lines are part of the set. A line may end with a number
        # giving the weight this element should take; this is controlled by
        # the weighted argument. Gold is REQUIRED to have this weight.
        if weighted:
            # The weight is the trailing field; rsplit keeps any separator
            # that appears inside the response text itself.
            text, weight = line.rsplit(FIELDSEP, 1)
        else:
            text, weight = line, 1
        if strip_punc:
            text = strip_punctuation(text)
        options[text] = float(weight)

    # The input may end without a trailing blank line; commit the last group.
    _store_group(key, options)
    return data