Sourcery refactored master branch #1

Open · wants to merge 1 commit into master
18 changes: 13 additions & 5 deletions docs/source/conf.py
@@ -175,7 +175,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
-htmlhelp_basename = project + "-doc"
+htmlhelp_basename = f"{project}-doc"
Lines 178-221 refactored with the following changes:
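
For reference, the concatenation-to-f-string rewrite in isolation (hypothetical value standing in for the conf.py variable):

    project = "torchmetrics"  # hypothetical stand-in
    assert project + "-doc" == f"{project}-doc"  # identical result, easier to read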


# -- Options for LaTeX output ------------------------------------------------

@@ -194,14 +194,21 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
-    (master_doc, project + ".tex", project + " Documentation", author, "manual"),
+    (
+        master_doc,
+        f"{project}.tex",
+        f"{project} Documentation",
+        author,
+        "manual",
+    )
]


# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
-man_pages = [(master_doc, project, project + " Documentation", [author], 1)]
+man_pages = [(master_doc, project, f"{project} Documentation", [author], 1)]

# -- Options for Texinfo output ----------------------------------------------

@@ -212,14 +219,15 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
    (
        master_doc,
        project,
-        project + " Documentation",
+        f"{project} Documentation",
        author,
        project,
        torchmetrics.__docs__,
        "Miscellaneous",
    ),
)
]


# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
3 changes: 1 addition & 2 deletions examples/bert_score-own_model.py
@@ -93,8 +93,7 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) ->
def get_user_model_encoder(num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_DIM, nhead: int = _NHEAD) -> Module:
    """Initialize the Transformer encoder."""
    encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
-    transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-    return transformer_encoder
+    return nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
Function get_user_model_encoder refactored with the following changes:
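
The inline-immediately-returned-variable pattern, sketched on a toy function (not from this repo):

    def build_squares(n: int) -> list:
        # before: result = [i * i for i in range(n)]; return result
        # after: return the expression directly; behavior is unchanged
        return [i * i for i in range(n)]

    assert build_squares(3) == [0, 1, 4]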



def user_forward_fn(model: Module, batch: Dict[str, Tensor]) -> Tensor:
8 changes: 2 additions & 6 deletions examples/rouge_score-own_normalizer_and_tokenizer.py
@@ -42,9 +42,7 @@ def __call__(self, text: str) -> str:
        Return:
            Normalized python string object
        """
-        output_text = re.sub(self.pattern, " ", text.lower())
-
-        return output_text
+        return re.sub(self.pattern, " ", text.lower())
Function UserNormalizer.__call__ refactored with the following changes:



class UserTokenizer:
@@ -65,9 +63,7 @@ def __call__(self, text: str) -> Sequence[str]:
        Return:
            Tokenized sentence
        """
-        output_tokens = re.split(self.pattern, text)
-
-        return output_tokens
+        return re.split(self.pattern, text)
Function UserTokenizer.__call__ refactored with the following changes:



_PREDS = ["hello", "hello world", "world world world"]
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/stat_scores.py
@@ -155,10 +155,10 @@ def __init__(
        default: Callable = lambda: []
        reduce_fn: Optional[str] = "cat"
        if mdmc_reduce != "samplewise" and reduce != "samples":
-            if reduce == "micro":
-                zeros_shape = []
-            elif reduce == "macro":
+            if reduce == "macro":
                zeros_shape = [num_classes]
Function StatScores.__init__ refactored with the following changes:

  • Simplify conditional into switch-like form (switch)

+            elif reduce == "micro":
+                zeros_shape = []
            else:
                raise ValueError(f'Wrong reduce="{reduce}"')
            default = lambda: torch.zeros(zeros_shape, dtype=torch.long)
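
A minimal sketch of the reordered branch logic above, as a standalone helper (hypothetical name):

    def zeros_shape_for(reduce: str, num_classes: int) -> list:
        if reduce == "macro":
            return [num_classes]
        if reduce == "micro":
            return []
        raise ValueError(f'Wrong reduce="{reduce}"')

    assert zeros_shape_for("macro", 4) == [4]
    assert zeros_shape_for("micro", 4) == []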
12 changes: 4 additions & 8 deletions src/torchmetrics/collections.py
@@ -219,9 +219,7 @@ def _merge_compute_groups(self) -> None:

        # Re-index groups
        temp = deepcopy(self._groups)
-        self._groups = {}
-        for idx, values in enumerate(temp.values()):
-            self._groups[idx] = values
+        self._groups = dict(enumerate(temp.values()))
Function MetricCollection._merge_compute_groups refactored with the following changes:
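
The re-indexing idiom in isolation (hypothetical toy data):

    temp = {"a": [1], "b": [2, 3]}
    assert dict(enumerate(temp.values())) == {0: [1], 1: [2, 3]}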


@staticmethod
def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
@@ -310,9 +308,7 @@ def add_metrics(
        # prepare for optional additions
        metrics = list(metrics)
        remain: list = []
-        for m in additional_metrics:
-            (metrics if isinstance(m, Metric) else remain).append(m)

+        (metrics if isinstance(m, Metric) else remain).extend(iter(additional_metrics))
Function MetricCollection.add_metrics refactored with the following changes:

        if remain:
            rank_zero_warn(
                f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored."
@@ -369,7 +365,7 @@ def _init_compute_groups(self) -> None:
        simply initialize each metric in the collection as its own group
        """
        if isinstance(self._enable_compute_groups, list):
-            self._groups = {i: k for i, k in enumerate(self._enable_compute_groups)}
+            self._groups = dict(enumerate(self._enable_compute_groups))
Function MetricCollection._init_compute_groups refactored with the following changes:

for v in self._groups.values():
for metric in v:
if metric not in self:
@@ -453,5 +449,5 @@ def __repr__(self) -> str:
        if self.prefix:
            repr_str += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
        if self.postfix:
-            repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
+            repr_str += f"{'' if self.prefix else ','}\n postfix={self.postfix}"
Function MetricCollection.__repr__ refactored with the following changes:

return repr_str + "\n)"
25 changes: 11 additions & 14 deletions src/torchmetrics/detection/mean_ap.py
@@ -42,18 +42,15 @@ def compute_area(input: List[Any], iou_type: str = "bbox") -> Tensor:

    Default output for empty input is torch.Tensor([])
    """
-    if len(input) == 0:
+    if not input:

        return torch.Tensor([])

    if iou_type == "bbox":
        return box_area(torch.stack(input))
    elif iou_type == "segm":

        input = [{"size": i[0], "counts": i[1]} for i in input]
-        area = torch.tensor(mask_utils.area(input).astype("float"))
-
-        return area
+        return torch.tensor(mask_utils.area(input).astype("float"))
Function compute_area refactored with the following changes:

    else:
        raise Exception(f"IOU type {iou_type} is not supported")
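
The emptiness check relies on standard sequence truthiness, e.g.:

    for seq in ([], [1, 2]):
        assert (len(seq) == 0) == (not seq)  # `not seq` is the idiomatic test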

@@ -191,9 +188,7 @@ def _input_validator(
def _fix_empty_tensors(boxes: Tensor) -> Tensor:
    """Empty tensors can cause problems in DDP mode, this methods corrects them."""

-    if boxes.numel() == 0 and boxes.ndim == 1:
-        return boxes.unsqueeze(0)
-    return boxes
+    return boxes.unsqueeze(0) if boxes.numel() == 0 and boxes.ndim == 1 else boxes
Function _fix_empty_tensors refactored with the following changes:
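
The collapsed conditional, exercised on a hypothetical empty input:

    import torch

    boxes = torch.empty(0)  # 1-d, zero elements
    fixed = boxes.unsqueeze(0) if boxes.numel() == 0 and boxes.ndim == 1 else boxes
    assert fixed.shape == (1, 0)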



class MeanAveragePrecision(Metric):
@@ -462,7 +457,7 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor:
        gt = [gt[i] for i in gt_label_mask]
        det = [det[i] for i in det_label_mask]

-        if len(gt) == 0 or len(det) == 0:
+        if not gt or not det:
Function MeanAveragePrecision._compute_iou refactored with the following changes:

return Tensor([])

# Sort by scores and use only max detections
@@ -475,8 +470,7 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor:
        if len(det) > max_det:
            det = det[:max_det]

-        ious = compute_iou(det, gt, self.iou_type).to(self.device)
-        return ious
+        return compute_iou(det, gt, self.iou_type).to(self.device)

def __evaluate_image_gt_no_preds(
self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], nb_iou_thrs: int
@@ -574,7 +568,7 @@ def _evaluate_image(

        gt = [gt[i] for i in gt_label_mask]
        det = [det[i] for i in det_label_mask]
-        if len(gt) == 0 and len(det) == 0:
+        if not gt and not det:
Function MeanAveragePrecision._evaluate_image refactored with the following changes:

return None
if isinstance(det, dict):
det = [det]
@@ -706,8 +700,11 @@ def _summarize(
        else:
            prec = prec[:, :, area_inds, mdet_inds]

-        mean_prec = torch.tensor([-1.0]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1])
-        return mean_prec
+        return (
+            torch.tensor([-1.0])
+            if len(prec[prec > -1]) == 0
+            else torch.mean(prec[prec > -1])
+        )
Function MeanAveragePrecision._summarize refactored with the following changes:
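
The fallback-to--1 masked mean, sketched with hypothetical values:

    import torch

    prec = torch.tensor([-1.0, 0.5, 0.7])  # -1 marks empty cells
    valid = prec[prec > -1]
    result = torch.tensor([-1.0]) if len(valid) == 0 else torch.mean(valid)
    assert torch.isclose(result, torch.tensor(0.6))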


def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResults]:
"""Calculate the precision and recall for all supplied classes to calculate mAP/mAR.
11 changes: 6 additions & 5 deletions src/torchmetrics/functional/audio/pit.py
@@ -131,7 +131,7 @@ def permutation_invariant_training(
    Reference:
        [1] `Permutation Invariant Training of Deep Models`_
    """
-    if preds.shape[0:2] != target.shape[0:2]:
+    if preds.shape[:2] != target.shape[:2]:
Function permutation_invariant_training refactored with the following changes:

        raise RuntimeError(
            "Predictions and targets are expected to have the same shape at the batch and speaker dimensions"
        )
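
`shape[0:2]` and `shape[:2]` are the same slice; a quick check with a hypothetical [batch, spk, time] tensor:

    import torch

    t = torch.zeros(4, 2, 8000)
    assert t.shape[0:2] == t.shape[:2] == (4, 2)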
@@ -141,7 +141,7 @@ def permutation_invariant_training(
        raise ValueError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead")

    # calculate the metric matrix
-    batch_size, spk_num = target.shape[0:2]
+    batch_size, spk_num = target.shape[:2]
metric_mtx = None
for target_idx in range(spk_num): # we have spk_num speeches in target in each sample
for preds_idx in range(spk_num): # we have spk_num speeches in preds in each sample
@@ -157,7 +157,7 @@ def permutation_invariant_training(
    # find best
    op = torch.max if eval_func == "max" else torch.min
    if spk_num < 3 or not _SCIPY_AVAILABLE:
-        if spk_num >= 3 and not _SCIPY_AVAILABLE:
+        if spk_num >= 3:
            warn(f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance")

        best_metric, best_perm = _find_best_perm_by_exhaustive_method(metric_mtx, op)
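
The dropped `and not _SCIPY_AVAILABLE` is implied by the enclosing branch; with hypothetical values:

    spk_num, scipy_available = 4, False
    if spk_num < 3 or not scipy_available:  # outer branch, as in the diff
        if spk_num >= 3:
            # the first disjunct is false here, so `not scipy_available` must hold
            assert not scipy_available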
@@ -177,5 +177,6 @@ def pit_permutate(preds: Tensor, perm: Tensor) -> Tensor:
    Returns:
        Tensor: the permutated version of estimate
    """
-    preds_pmted = torch.stack([torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)])
-    return preds_pmted
+    return torch.stack(
+        [torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)]
+    )
Function pit_permutate refactored with the following changes:

7 changes: 2 additions & 5 deletions src/torchmetrics/functional/audio/sdr.py
@@ -212,7 +212,7 @@ def signal_distortion_ratio(
            "provided by Pytorch is used.",
            UserWarning,
        )
-    elif not _TORCH_GREATER_EQUAL_1_8:
+    else:
Function signal_distortion_ratio refactored with the following changes:

warnings.warn(
"The `use_cg_iter` parameter of `SDR` requires a Pytorch version >= 1.8. "
"To make this this warning disappear, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. "
@@ -230,10 +230,7 @@ def signal_distortion_ratio(
    ratio = coh / (1 - coh)
    val = 10.0 * torch.log10(ratio)

-    if preds_dtype == torch.float64:
-        return val
-    else:
-        return val.float()
+    return val if preds_dtype == torch.float64 else val.float()


def scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
9 changes: 5 additions & 4 deletions src/torchmetrics/functional/classification/accuracy.py
@@ -56,7 +56,7 @@ def _mode(
    <DataType.MULTICLASS: 'multi-class'>
    """

-    mode = _check_classification_inputs(
+    return _check_classification_inputs(

        preds,
        target,
        threshold=threshold,
@@ -65,7 +65,6 @@ def _mode(
        multiclass=multiclass,
        ignore_index=ignore_index,
    )
-    return mode


def _accuracy_update(
@@ -394,7 +393,9 @@ def accuracy(
    if average not in allowed_average:
        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

-    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
+    if average in {"macro", "weighted", "none", None} and (
+        (not num_classes or num_classes < 1)
+    ):
Function accuracy refactored with the following changes:

raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")

allowed_mdmc_average = [None, "samplewise", "global"]
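
List and set literals give the same membership result here; the set signals intent and hashes lookups:

    average = "weighted"  # hypothetical argument
    in_list = average in ["macro", "weighted", "none", None]
    in_set = average in {"macro", "weighted", "none", None}
    assert in_list == in_set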
@@ -409,7 +410,7 @@ def accuracy(

    preds, target = _input_squeeze(preds, target)
    mode = _mode(preds, target, threshold, top_k, num_classes, multiclass, ignore_index)
-    reduce = "macro" if average in ["weighted", "none", None] else average
+    reduce = "macro" if average in {"weighted", "none", None} else average

if subset_accuracy and _check_subset_validity(mode):
correct, total = _subset_accuracy_update(preds, target, threshold, top_k, ignore_index)
@@ -151,9 +151,10 @@ def _average_precision_compute_with_precision_recall(
    if num_classes == 1:
        return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

-    res = []
-    for p, r in zip(precision, recall):
-        res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1]))
+    res = [
+        -torch.sum((r[1:] - r[:-1]) * p[:-1])
+        for p, r in zip(precision, recall)
+    ]
Function _average_precision_compute_with_precision_recall refactored with the following changes:
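
The append loop and the comprehension are equivalent; a toy check:

    vals = [1, 2, 3]
    out = []
    for v in vals:
        out.append(v * 2)
    assert out == [v * 2 for v in vals]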


# Reduce
if average in ("macro", "weighted"):
@@ -143,11 +143,11 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    _, _, mode = _input_format_classification(preds, target)

    if mode == DataType.BINARY:
-        if not ((0 <= preds) * (preds <= 1)).all():
+        if not ((preds >= 0) * (preds <= 1)).all():
            preds = preds.sigmoid()
        confidences, accuracies = preds, target
    elif mode == DataType.MULTICLASS:
-        if not ((0 <= preds) * (preds <= 1)).all():
+        if not ((preds >= 0) * (preds <= 1)).all():
Function _ce_update refactored with the following changes:

            preds = preds.softmax(dim=1)
        confidences, predictions = preds.max(dim=1)
        accuracies = predictions.eq(target)
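
`0 <= preds` and `preds >= 0` produce the same boolean mask; a quick check on hypothetical probabilities:

    import torch

    preds = torch.tensor([0.2, 0.8])
    assert torch.equal(0 <= preds, preds >= 0)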
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/cohen_kappa.py
@@ -41,7 +41,7 @@ def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tens
    """

    confmat = _confusion_matrix_compute(confmat)
-    confmat = confmat.float() if not confmat.is_floating_point() else confmat
+    confmat = confmat if confmat.is_floating_point() else confmat.float()
Function _cohen_kappa_compute refactored with the following changes:

    n_classes = confmat.shape[0]
    sum0 = confmat.sum(dim=0, keepdim=True)
    sum1 = confmat.sum(dim=1, keepdim=True)
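
The de-negated conditional in isolation (hypothetical integer confusion matrix):

    import torch

    confmat = torch.tensor([[2, 1], [0, 3]])
    out = confmat if confmat.is_floating_point() else confmat.float()
    assert out.dtype == torch.float32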
12 changes: 6 additions & 6 deletions src/torchmetrics/functional/classification/confusion_matrix.py
@@ -47,11 +47,11 @@ def _confusion_matrix_update(
    minlength = num_classes**2

    bins = _bincount(unique_mapping, minlength=minlength)
-    if multilabel:
-        confmat = bins.reshape(num_classes, 2, 2)
-    else:
-        confmat = bins.reshape(num_classes, num_classes)
-    return confmat
+    return (
+        bins.reshape(num_classes, 2, 2)
+        if multilabel
+        else bins.reshape(num_classes, num_classes)
+    )
Comment on lines -50 to +54
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function _confusion_matrix_update refactored with the following changes:



def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor:
@@ -98,7 +98,7 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None)
    if normalize not in allowed_normalize:
        raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}")
    if normalize is not None and normalize != "none":
-        confmat = confmat.float() if not confmat.is_floating_point() else confmat
+        confmat = confmat if confmat.is_floating_point() else confmat.float()
Function _confusion_matrix_compute refactored with the following changes:

if normalize == "true":
confmat = confmat / confmat.sum(axis=1, keepdim=True)
elif normalize == "pred":
11 changes: 5 additions & 6 deletions src/torchmetrics/functional/classification/dice.py
@@ -90,10 +90,7 @@ def dice_score(
    if zero_division != nan_score:
        rank_zero_warn(f"Deprecated parameter. `nan_score` converted to integer {zero_division}.")

-    ignore_index = None
-    if not bg:
-        ignore_index = 0
-
+    ignore_index = None if bg else 0
Function dice_score refactored with the following changes:

return dice(
preds,
target,
@@ -272,7 +269,9 @@ def dice(
    if average not in allowed_average:
        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

-    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
+    if average in {"macro", "weighted", "none", None} and (
+        (not num_classes or num_classes < 1)
+    ):
Function dice refactored with the following changes:

raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")

allowed_mdmc_average = [None, "samplewise", "global"]
@@ -286,7 +285,7 @@ def dice(
        raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

    preds, target = _input_squeeze(preds, target)
-    reduce = "macro" if average in ("weighted", "none", None) else average
+    reduce = "macro" if average in {"weighted", "none", None} else average

tp, fp, _, fn = _stat_scores_update(
preds,