
Commit

Merge pull request #195 from neulab/e501
Complying with E501

Former-commit-id: ee63b46
Yusuke Oda authored Mar 31, 2022
2 parents 84792e3 + c444740 commit bc49336
Showing 28 changed files with 1,383 additions and 294 deletions.
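For context, E501 is flake8's "line too long" check; the changes in this commit wrap long strings, comments, and docstrings so each line fits within the project's length limit. The project's flake8 settings are not part of this diff, but a typical configuration enforcing such a limit looks roughly like the sketch below (the 88-character value is an assumption, not taken from this repository):

[flake8]
# Hypothetical settings; the repository's actual limit is not shown in this diff.
max-line-length = 88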
14 changes: 10 additions & 4 deletions explainaboard/explainaboard_main.py
@@ -20,7 +20,10 @@ def main():
type=str,
required=True,
nargs="+",
help="the directories of system outputs. Multiple one should be separated by space, for example: system1 system2",
help=(
"the directories of system outputs. Multiple one should be separated by "
"space, for example: system1 system2"
),
)

parser.add_argument(
@@ -84,11 +87,13 @@ def main():
# Checks on inputs
if num_outputs > 2:
raise ValueError(
f'ExplainaBoard currently only supports 1 or 2 system outputs, but received {num_outputs}'
f'ExplainaBoard currently only supports 1 or 2 system outputs, but '
f'received {num_outputs}'
)
if task not in TaskType.list():
raise ValueError(
f'Task name {task} was not recognized. ExplainaBoard currently supports: {TaskType.list()}'
f'Task name {task} was not recognized. ExplainaBoard currently supports: '
f'{TaskType.list()}'
)

# Read in data and check validity
@@ -101,7 +106,8 @@
num0 = len(system_datasets[0])
num1 = len(system_datasets[1])
raise ValueError(
f'Data must be identical for pairwise analysis, but length of files {system_datasets[0]} ({num0}) != {system_datasets[1]} ({num1})'
f'Data must be identical for pairwise analysis, but length of files '
f'{system_datasets[0]} ({num0}) != {system_datasets[1]} ({num1})'
)
if (
loaders[0].user_defined_features_configs
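The wrapping pattern used throughout this file relies on implicit string concatenation: adjacent string literals (including f-strings) inside parentheses are joined into a single string, so the wrapped error messages are identical to the original long lines. A small standalone sketch of the idea (the message text mirrors the diff; the variable names are illustrative):

num_outputs = 3

# E501-compliant form: adjacent f-string literals inside parentheses are
# concatenated by the parser into one string.
wrapped = (
    f'ExplainaBoard currently only supports 1 or 2 system outputs, but '
    f'received {num_outputs}'
)

# The same text split at a different point; the resulting strings are identical.
reference = (
    f'ExplainaBoard currently only supports 1 or 2 system outputs,'
    f' but received {num_outputs}'
)
assert wrapped == reference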
69 changes: 43 additions & 26 deletions explainaboard/feature.py
@@ -18,8 +18,8 @@

def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
"""
_arrow_to_datasets_dtype takes a pyarrow.DataType and converts it to a datasets string dtype.
In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
_arrow_to_datasets_dtype takes a pyarrow.DataType and converts it to a datasets
string dtype. In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
"""

if pa.types.is_null(arrow_type):
@@ -74,11 +74,11 @@ def string_to_arrow(datasets_dtype: str) -> pa.DataType:
"""
string_to_arrow takes a datasets string dtype and converts it to a pyarrow.DataType.
In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
This is necessary because the datasets.Value() primitive type is constructed using a string dtype
Value(dtype=str)
But Features.type (via `get_nested_type()` expects to resolve Features into a pyarrow Schema,
which means that each Value() must be able to resolve into a corresponding pyarrow.DataType, which is the
purpose of this function.
This is necessary because the datasets.Value() primitive type is constructed using a
string dtype Value(dtype=str)
But Features.type (via `get_nested_type()` expects to resolve Features into a
pyarrow Schema, which means that each Value() must be able to resolve into a
corresponding pyarrow.DataType, which is the purpose of this function.
"""
timestamp_regex = re.compile(r"^timestamp\[(.*)\]$")
timestamp_matches = timestamp_regex.search(datasets_dtype)
@@ -97,16 +97,21 @@ def string_to_arrow(datasets_dtype: str) -> pa.DataType:
return pa.timestamp(internals_matches.group(1), internals_matches.group(2))
else:
raise ValueError(
f"{datasets_dtype} is not a validly formatted string representation of a pyarrow timestamp."
f"Examples include timestamp[us] or timestamp[us, tz=America/New_York]"
f"See: https://arrow.apache.org/docs/python/generated/pyarrow.timestamp.html#pyarrow.timestamp"
f"""
{datasets_dtype} is not a validly formatted string representation of a pyarrow
timestamp. Examples include timestamp[us] or timestamp[us, tz=America/New_York]
See:
https://arrow.apache.org/docs/python/generated/pyarrow.timestamp.html#pyarrow.timestamp
"""
)
elif datasets_dtype not in pa.__dict__:
if str(datasets_dtype + "_") not in pa.__dict__:
raise ValueError(
f"Neither {datasets_dtype} nor {datasets_dtype + '_'} seems to be a pyarrow data type. "
f"Please make sure to use a correct data type, see: "
f"https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions"
f"""
Neither {datasets_dtype} nor {datasets_dtype + '_'} seems to be a pyarrow data type.
Please make sure to use a correct data type, see:
https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
"""
)
arrow_data_factory_function_name = str(datasets_dtype + "_")
else:
@@ -119,17 +124,23 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool) -> tuple[Any, boo
"""
Cast pytorch/tensorflow/pandas objects to python numpy array/lists.
It works recursively.
To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
To avoid iterating over possibly long lists, it first checks if the first element
that is not None has to be casted.
If the first element needs to be casted, then all the elements of the list will be
casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating
over every single token for example.
Args:
obj: the object (nested struct) to cast
only_1d_for_numpy (bool): whether to keep the full multi-dim tensors as multi-dim numpy arrays, or convert them to
nested lists of 1-dimensional numpy arrays. This can be useful to keep only 1-d arrays to instantiate Arrow arrays.
Indeed Arrow only support converting 1-dimensional array values.
only_1d_for_numpy (bool): whether to keep the full multi-dim tensors as
multi-dim numpy arrays, or convert them to nested lists of 1-dimensional
numpy arrays. This can be useful to keep only 1-d arrays to instantiate
Arrow arrays. Indeed Arrow only support converting 1-dimensional array
values.
Returns:
casted_obj: the casted object
has_changed (bool): True if the object has been changed, False if it is identical
has_changed (bool): True if the object has been changed, False if it is
identical
"""

if config.TF_AVAILABLE and "tensorflow" in sys.modules:
@@ -240,9 +251,12 @@ def cast_to_python_objects(obj: Any, only_1d_for_numpy=False) -> Any:
"""
Cast numpy/pytorch/tensorflow/pandas objects to python lists.
It works recursively.
To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
To avoid iterating over possibly long lists, it first checks if the first element
that is not None has to be casted.
If the first element needs to be casted, then all the elements of the list will be
casted, otherwise they'll stay the same.
This trick allows to cast objects that contain tokenizers outputs without iterating
over every single token for example.
Args:
obj: the object (nested struct) to cast
Returns:
@@ -552,7 +566,8 @@ def encode_example(self, value):

def encode_nested_example(schema, obj):
"""Encode a nested example.
This is used since some features (in particular ClassLabel) have some logic during encoding.
This is used since some features (in particular ClassLabel) have some logic during
encoding.
"""
# Nested structures: we allow dict, list/tuples, sequences
if isinstance(schema, dict):
Expand Down Expand Up @@ -598,10 +613,12 @@ def encode_nested_example(schema, obj):
else None
)
# Object with special encoding:
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
# ClassLabel will convert from string to int,
# TranslationVariableLanguages does some checks
elif isinstance(schema, (ClassLabel, Value)):
return schema.encode_example(obj)
# Other object should be directly convertible to a native Arrow type (like Translation and Translation)
# Other object should be directly convertible to a native Arrow type
# (like Translation and Translation)
return obj


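The docstrings reflowed above describe a round trip between datasets-style string dtypes and pyarrow types, i.e. dt == string_to_arrow(_arrow_to_datasets_dtype(dt)). A simplified, self-contained sketch of the string-to-pyarrow direction (minimal_string_to_arrow is a hypothetical reduction of the function in this diff and omits the timezone-aware timestamp branch):

import re

import pyarrow as pa


def minimal_string_to_arrow(datasets_dtype: str) -> pa.DataType:
    # "timestamp[us]"-style dtypes go through the pa.timestamp() factory;
    # the "tz=..." variant handled by the full implementation is omitted here.
    match = re.search(r"^timestamp\[(.*)\]$", datasets_dtype)
    if match and "," not in match.group(1):
        return pa.timestamp(match.group(1))
    # Otherwise resolve a pyarrow factory function by name, falling back to the
    # trailing-underscore spelling (e.g. "bool" -> pa.bool_).
    name = datasets_dtype if datasets_dtype in pa.__dict__ else datasets_dtype + "_"
    if name not in pa.__dict__:
        raise ValueError(f"{datasets_dtype} is not a recognized pyarrow data type")
    return pa.__dict__[name]()


assert minimal_string_to_arrow("int64") == pa.int64()
assert minimal_string_to_arrow("timestamp[us]") == pa.timestamp("us")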
14 changes: 5 additions & 9 deletions explainaboard/info.py
@@ -74,9 +74,10 @@ class SysOutputInfo:
download_link (str): the url of the system output.
paper (Paper, optional): the published paper of the system.
features (Features, optional): the features used to describe system output's
column type.
column type.
is_print_case (bool): Whether or not to print out cases
is_print_confidence_interval (bool): Whether or not to print out confidence intervals
is_print_confidence_interval (bool): Whether or not to print out confidence
intervals
"""

# set in the system_output scripts
@@ -141,8 +142,8 @@ def _dump_info(self, file):
def from_directory(cls, sys_output_info_dir: str) -> "SysOutputInfo":
"""Create SysOutputInfo from the JSON file in `sys_output_info_dir`.
Args:
sys_output_info_dir (`str`): The directory containing the metadata file. This
should be the root directory of a specific dataset version.
sys_output_info_dir (`str`): The directory containing the metadata file.
This should be the root directory of a specific dataset version.
"""
logger.info("Loading Dataset info from %s", sys_output_info_dir)
if not sys_output_info_dir:
@@ -158,11 +159,6 @@ def from_directory(cls, sys_output_info_dir: str) -> "SysOutputInfo":
sys_output_info_dict = json.load(f)
return cls.from_dict(sys_output_info_dict)

# @classmethod
# def from_dict(cls, task_name: str, sys_output_info_dict: dict) -> "SysOutputInfo":
# field_names = set(f.name for f in dataclasses.fields(cls))
# return cls(task_name, **{k: v for k, v in sys_output_info_dict.items() if k in field_names})

@classmethod
def from_dict(cls, sys_output_info_dict: dict) -> "SysOutputInfo":
field_names = set(f.name for f in dataclasses.fields(cls))
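The from_dict classmethod at the end of this file's hunks keeps only the keys that correspond to declared dataclass fields, so extra keys in a metadata JSON file are ignored rather than raising a TypeError. A self-contained sketch of that pattern (ExampleInfo is a hypothetical stand-in for SysOutputInfo):

import dataclasses


@dataclasses.dataclass
class ExampleInfo:
    task_name: str = ""
    is_print_case: bool = False


def from_dict(cls, info_dict: dict) -> "ExampleInfo":
    # Drop any keys that are not declared fields before calling the constructor.
    field_names = {f.name for f in dataclasses.fields(cls)}
    return cls(**{k: v for k, v in info_dict.items() if k in field_names})


info = from_dict(ExampleInfo, {"task_name": "text-classification", "unexpected": 1})
assert info.task_name == "text-classification" and info.is_print_case is False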
33 changes: 19 additions & 14 deletions explainaboard/loaders/file_loader.py
@@ -14,16 +14,19 @@
class FileLoaderField:
"""
Args:
src_name: field name in the source file. use int for tsv column indices and use str for dict keys
src_name: field name in the source file. use int for tsv column indices and use
str for dict keys
target_name: field name expected in the loaded data
dtype: data type of the field in the loaded data. It is only intended for simple type conversion so
it only supports int, float and str. Pass in None to turn off type conversion.
strip_before_parsing: call strip() on strings before casting to either str, int or float. It is
only intended to be used with these three data types. It defaults to True for str. For all other
types, it defaults to False
parser: a custom parser for the field. When called, `data_points[idx][src_name]` is passed in as input,
it is expected to return the parsed result. If parser is not None, `strip_before_parsing` and dtype
will not have any effect
dtype: data type of the field in the loaded data. It is only intended for simple
type conversion so it only supports int, float and str. Pass in None to turn
off type conversion.
strip_before_parsing: call strip() on strings before casting to either str, int
or float. It is only intended to be used with these three data types.
It defaults to True for str. For all other types, it defaults to False
parser: a custom parser for the field. When called, `data_points[idx][src_name]`
is passed in as input, it is expected to return the parsed result.
If parser is not None, `strip_before_parsing` and dtype will not have any
effect.
"""

src_name: Union[int, str]
@@ -94,7 +97,8 @@ def generate_id(self, parsed_data_point: dict, sample_idx: int):
elif self._id_field_name:
if self._id_field_name not in parsed_data_point:
raise ValueError(
f"The {sample_idx} data point in system outputs file does not have field {self._id_field_name}"
f"The {sample_idx} data point in system outputs file does not have "
f"field {self._id_field_name}"
)
parsed_data_point["id"] = str(parsed_data_point[self._id_field_name])

@@ -104,16 +108,17 @@ def load_raw(cls, data: str, source: Source) -> Iterable:
fields information to parse the data points.
Args:
data (str): base64 encoded system output content or a path for the system output file
source: source of data
data (str): base64 encoded system output content or a path for the system
output file
source: source of data
"""
raise NotImplementedError(
"load_raw() is not implemented for the base FileLoader"
)

def load(self, data: str, source: Source) -> Iterable[dict]:
"""Load data from source, parse data points with fields information and return an
iterable of data points.
"""Load data from source, parse data points with fields information and return
an iterable of data points.
"""
raw_data = self.load_raw(data, source)
parsed_data_points: list[dict] = []
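The reflowed FileLoaderField docstring above spells out how the field options interact: a custom parser takes precedence, otherwise strings may be stripped and then cast to the declared dtype. A rough sketch of that precedence (SimpleField and parse_value are hypothetical illustrations, not the repository's actual FileLoader API):

from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass
class SimpleField:
    src_name: Union[int, str]     # int for tsv column indices, str for dict keys
    target_name: str
    dtype: Optional[type] = None  # only int, float and str are intended
    strip_before_parsing: bool = True
    parser: Optional[Callable] = None


def parse_value(raw_value, field: SimpleField):
    # A custom parser overrides strip_before_parsing and dtype entirely.
    if field.parser is not None:
        return field.parser(raw_value)
    if field.strip_before_parsing and isinstance(raw_value, str):
        raw_value = raw_value.strip()
    return field.dtype(raw_value) if field.dtype is not None else raw_value


assert parse_value(" 0.95 ", SimpleField("score", "score", dtype=float)) == 0.95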
10 changes: 6 additions & 4 deletions explainaboard/loaders/kg_link_tail_prediction.py
@@ -13,7 +13,7 @@
class KgLinkTailPredictionLoader(Loader):
"""
Validate and Reformat system output file with json format:
"head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
"head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
usage:
please refer to `test_loaders.py`
@@ -24,14 +24,16 @@ class KgLinkTailPredictionLoader(Loader):

def load(self) -> Iterable[dict]:
"""
:param path_system_output: the path of system output file with following format:
"head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
:param path_system_output:
the path of system output file with following format:
"head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
:return: class object
"""
data: list[dict] = []

# TODO(odashi): Avoid potential bug: load_raw returns Iterable[Any] which is not a dict.
# TODO(odashi):
# Avoid potential bug: load_raw returns Iterable[Any] which is not a dict.
raw_data: dict[str, dict[str, str]] = self.file_loaders[ # type: ignore
unwrap(self._file_type)
].load_raw(self._data, self._source)
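The documented format packs the head, relation, and true tail into a single tab-separated key whose value is the ranked list of predicted tails. A tiny sketch of reading one such entry (the triple and the result field names are made up for illustration):

# One hypothetical entry in the documented "head \t relation \t trueTail" format.
system_output = {
    "Japan\tcapital\tTokyo": ["Tokyo", "Kyoto", "Osaka", "Nagoya", "Sapporo"],
}

data = []
for link, pred_tails in system_output.items():
    head, relation, true_tail = link.split("\t")
    data.append(
        {
            "head": head,
            "relation": relation,
            "true_tail": true_tail,
            "predictions": pred_tails,
        }
    )

assert data[0]["relation"] == "capital" and data[0]["predictions"][0] == "Tokyo"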
10 changes: 6 additions & 4 deletions explainaboard/loaders/loader.py
@@ -21,8 +21,8 @@ class Loader:
data: base64 encoded system output content or a path for the system output file
source: source of data
file type: tsv, json, conll, etc.
file_loaders: a dict of file loaders. To customize the loading process, either implement
a custome FileLoader or override `load()`
file_loaders: a dict of file loaders. To customize the loading process, either
implement a custome FileLoader or override `load()`
"""

_default_source = Source.local_filesystem
@@ -49,7 +49,8 @@ def __init__(

if self._file_type not in self.file_loaders:
raise NotImplementedError(
f"A file loader for {self._file_type} is not provided. please add it to the file_loaders."
f"A file loader for {self._file_type} is not provided. "
"please add it to the file_loaders."
)

self._user_defined_features_configs: dict = (
@@ -60,7 +61,8 @@ def __init__(
def user_defined_features_configs(self) -> dict:
if self._user_defined_features_configs is None:
raise Exception(
"User defined features configs are not available (data has not been loaded))"
"User defined features configs are not available "
"(data has not been loaded))"
)
return self._user_defined_features_configs

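Loader keeps a mapping from file type to FileLoader objects and fails fast with the wrapped NotImplementedError above when a type is not registered. A self-contained sketch of that registry idea (load_tsv and the column layout are hypothetical, not the repository's actual FileLoader classes):

from typing import Callable


def load_tsv(raw: str) -> list[dict]:
    # Hypothetical stand-in for a FileLoader: one dict per line, keyed by column index.
    return [dict(enumerate(line.split("\t"))) for line in raw.splitlines()]


file_loaders: dict[str, Callable[[str], list[dict]]] = {"tsv": load_tsv}

file_type = "tsv"
if file_type not in file_loaders:
    # Mirrors the error raised in Loader.__init__ for unregistered file types.
    raise NotImplementedError(
        f"A file loader for {file_type} is not provided. "
        "please add it to the file_loaders."
    )

rows = file_loaders[file_type]("Tokyo\t1\nKyoto\t0")
assert rows[0][0] == "Tokyo" and rows[1][1] == "0"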
7 changes: 4 additions & 3 deletions explainaboard/loaders/qa_multiple_choice.py
@@ -13,7 +13,7 @@
class QAMultipleChoiceLoader(Loader):
"""
Validate and Reformat system output file with json format:
"head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
"head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
usage:
please refer to `test_loaders.py`
@@ -26,8 +26,9 @@ class QAMultipleChoiceLoader(Loader):

def load(self) -> Iterable[dict]:
"""
:param path_system_output: the path of system output file with following format:
"head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
:param path_system_output:
the path of system output file with following format:
"head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
:return: class object
"""