-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored master branch #1
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,8 +93,7 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> | |
def get_user_model_encoder(num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_DIM, nhead: int = _NHEAD) -> Module: | ||
"""Initialize the Transformer encoder.""" | ||
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead) | ||
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) | ||
return transformer_encoder | ||
return nn.TransformerEncoder(encoder_layer, num_layers=num_layers) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
def user_forward_fn(model: Module, batch: Dict[str, Tensor]) -> Tensor: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,9 +42,7 @@ def __call__(self, text: str) -> str: | |
Return: | ||
Normalized python string object | ||
""" | ||
output_text = re.sub(self.pattern, " ", text.lower()) | ||
|
||
return output_text | ||
return re.sub(self.pattern, " ", text.lower()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
class UserTokenizer: | ||
|
@@ -65,9 +63,7 @@ def __call__(self, text: str) -> Sequence[str]: | |
Return: | ||
Tokenized sentence | ||
""" | ||
output_tokens = re.split(self.pattern, text) | ||
|
||
return output_tokens | ||
return re.split(self.pattern, text) | ||
Comment on lines
-68
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
_PREDS = ["hello", "hello world", "world world world"] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,10 +155,10 @@ def __init__( | |
default: Callable = lambda: [] | ||
reduce_fn: Optional[str] = "cat" | ||
if mdmc_reduce != "samplewise" and reduce != "samples": | ||
if reduce == "micro": | ||
zeros_shape = [] | ||
elif reduce == "macro": | ||
if reduce == "macro": | ||
zeros_shape = [num_classes] | ||
Comment on lines
-158
to
161
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
elif reduce == "micro": | ||
zeros_shape = [] | ||
else: | ||
raise ValueError(f'Wrong reduce="{reduce}"') | ||
default = lambda: torch.zeros(zeros_shape, dtype=torch.long) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -219,9 +219,7 @@ def _merge_compute_groups(self) -> None: | |
|
||
# Re-index groups | ||
temp = deepcopy(self._groups) | ||
self._groups = {} | ||
for idx, values in enumerate(temp.values()): | ||
self._groups[idx] = values | ||
self._groups = dict(enumerate(temp.values())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
@staticmethod | ||
def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool: | ||
|
@@ -310,9 +308,7 @@ def add_metrics( | |
# prepare for optional additions | ||
metrics = list(metrics) | ||
remain: list = [] | ||
for m in additional_metrics: | ||
(metrics if isinstance(m, Metric) else remain).append(m) | ||
|
||
(metrics if isinstance(m, Metric) else remain).extend(iter(additional_metrics)) | ||
Comment on lines
-313
to
+311
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
if remain: | ||
rank_zero_warn( | ||
f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." | ||
|
@@ -369,7 +365,7 @@ def _init_compute_groups(self) -> None: | |
simply initialize each metric in the collection as its own group | ||
""" | ||
if isinstance(self._enable_compute_groups, list): | ||
self._groups = {i: k for i, k in enumerate(self._enable_compute_groups)} | ||
self._groups = dict(enumerate(self._enable_compute_groups)) | ||
Comment on lines
-372
to
+368
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
for v in self._groups.values(): | ||
for metric in v: | ||
if metric not in self: | ||
|
@@ -453,5 +449,5 @@ def __repr__(self) -> str: | |
if self.prefix: | ||
repr_str += f",\n prefix={self.prefix}{',' if self.postfix else ''}" | ||
if self.postfix: | ||
repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}" | ||
repr_str += f"{'' if self.prefix else ','}\n postfix={self.postfix}" | ||
Comment on lines
-456
to
+452
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
return repr_str + "\n)" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,18 +42,15 @@ def compute_area(input: List[Any], iou_type: str = "bbox") -> Tensor: | |
|
||
Default output for empty input is torch.Tensor([]) | ||
""" | ||
if len(input) == 0: | ||
if not input: | ||
|
||
return torch.Tensor([]) | ||
|
||
if iou_type == "bbox": | ||
return box_area(torch.stack(input)) | ||
elif iou_type == "segm": | ||
|
||
input = [{"size": i[0], "counts": i[1]} for i in input] | ||
area = torch.tensor(mask_utils.area(input).astype("float")) | ||
|
||
return area | ||
return torch.tensor(mask_utils.area(input).astype("float")) | ||
Comment on lines
-45
to
+53
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
else: | ||
raise Exception(f"IOU type {iou_type} is not supported") | ||
|
||
|
@@ -191,9 +188,7 @@ def _input_validator( | |
def _fix_empty_tensors(boxes: Tensor) -> Tensor: | ||
"""Empty tensors can cause problems in DDP mode, this methods corrects them.""" | ||
|
||
if boxes.numel() == 0 and boxes.ndim == 1: | ||
return boxes.unsqueeze(0) | ||
return boxes | ||
return boxes.unsqueeze(0) if boxes.numel() == 0 and boxes.ndim == 1 else boxes | ||
Comment on lines
-194
to
+191
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
class MeanAveragePrecision(Metric): | ||
|
@@ -462,7 +457,7 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: | |
gt = [gt[i] for i in gt_label_mask] | ||
det = [det[i] for i in det_label_mask] | ||
|
||
if len(gt) == 0 or len(det) == 0: | ||
if not gt or not det: | ||
Comment on lines
-465
to
+460
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
return Tensor([]) | ||
|
||
# Sort by scores and use only max detections | ||
|
@@ -475,8 +470,7 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor: | |
if len(det) > max_det: | ||
det = det[:max_det] | ||
|
||
ious = compute_iou(det, gt, self.iou_type).to(self.device) | ||
return ious | ||
return compute_iou(det, gt, self.iou_type).to(self.device) | ||
|
||
def __evaluate_image_gt_no_preds( | ||
self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], nb_iou_thrs: int | ||
|
@@ -574,7 +568,7 @@ def _evaluate_image( | |
|
||
gt = [gt[i] for i in gt_label_mask] | ||
det = [det[i] for i in det_label_mask] | ||
if len(gt) == 0 and len(det) == 0: | ||
if not gt and not det: | ||
Comment on lines
-577
to
+571
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
return None | ||
if isinstance(det, dict): | ||
det = [det] | ||
|
@@ -706,8 +700,11 @@ def _summarize( | |
else: | ||
prec = prec[:, :, area_inds, mdet_inds] | ||
|
||
mean_prec = torch.tensor([-1.0]) if len(prec[prec > -1]) == 0 else torch.mean(prec[prec > -1]) | ||
return mean_prec | ||
return ( | ||
torch.tensor([-1.0]) | ||
if len(prec[prec > -1]) == 0 | ||
else torch.mean(prec[prec > -1]) | ||
) | ||
Comment on lines
-709
to
+707
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResults]: | ||
"""Calculate the precision and recall for all supplied classes to calculate mAP/mAR. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,7 +131,7 @@ def permutation_invariant_training( | |
Reference: | ||
[1] `Permutation Invariant Training of Deep Models`_ | ||
""" | ||
if preds.shape[0:2] != target.shape[0:2]: | ||
if preds.shape[:2] != target.shape[:2]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
raise RuntimeError( | ||
"Predictions and targets are expected to have the same shape at the batch and speaker dimensions" | ||
) | ||
|
@@ -141,7 +141,7 @@ def permutation_invariant_training( | |
raise ValueError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") | ||
|
||
# calculate the metric matrix | ||
batch_size, spk_num = target.shape[0:2] | ||
batch_size, spk_num = target.shape[:2] | ||
metric_mtx = None | ||
for target_idx in range(spk_num): # we have spk_num speeches in target in each sample | ||
for preds_idx in range(spk_num): # we have spk_num speeches in preds in each sample | ||
|
@@ -157,7 +157,7 @@ def permutation_invariant_training( | |
# find best | ||
op = torch.max if eval_func == "max" else torch.min | ||
if spk_num < 3 or not _SCIPY_AVAILABLE: | ||
if spk_num >= 3 and not _SCIPY_AVAILABLE: | ||
if spk_num >= 3: | ||
warn(f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance") | ||
|
||
best_metric, best_perm = _find_best_perm_by_exhaustive_method(metric_mtx, op) | ||
|
@@ -177,5 +177,6 @@ def pit_permutate(preds: Tensor, perm: Tensor) -> Tensor: | |
Returns: | ||
Tensor: the permutated version of estimate | ||
""" | ||
preds_pmted = torch.stack([torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)]) | ||
return preds_pmted | ||
return torch.stack( | ||
[torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)] | ||
) | ||
Comment on lines
-180
to
+182
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -212,7 +212,7 @@ def signal_distortion_ratio( | |
"provided by Pytorch is used.", | ||
UserWarning, | ||
) | ||
elif not _TORCH_GREATER_EQUAL_1_8: | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
warnings.warn( | ||
"The `use_cg_iter` parameter of `SDR` requires a Pytorch version >= 1.8. " | ||
"To make this this warning disappear, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. " | ||
|
@@ -230,10 +230,7 @@ def signal_distortion_ratio( | |
ratio = coh / (1 - coh) | ||
val = 10.0 * torch.log10(ratio) | ||
|
||
if preds_dtype == torch.float64: | ||
return val | ||
else: | ||
return val.float() | ||
return val if preds_dtype == torch.float64 else val.float() | ||
|
||
|
||
def scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,7 +56,7 @@ def _mode( | |
<DataType.MULTICLASS: 'multi-class'> | ||
""" | ||
|
||
mode = _check_classification_inputs( | ||
return _check_classification_inputs( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
preds, | ||
target, | ||
threshold=threshold, | ||
|
@@ -65,7 +65,6 @@ def _mode( | |
multiclass=multiclass, | ||
ignore_index=ignore_index, | ||
) | ||
return mode | ||
|
||
|
||
def _accuracy_update( | ||
|
@@ -394,7 +393,9 @@ def accuracy( | |
if average not in allowed_average: | ||
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") | ||
|
||
if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): | ||
if average in {"macro", "weighted", "none", None} and ( | ||
(not num_classes or num_classes < 1) | ||
): | ||
Comment on lines
-397
to
+398
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") | ||
|
||
allowed_mdmc_average = [None, "samplewise", "global"] | ||
|
@@ -409,7 +410,7 @@ def accuracy( | |
|
||
preds, target = _input_squeeze(preds, target) | ||
mode = _mode(preds, target, threshold, top_k, num_classes, multiclass, ignore_index) | ||
reduce = "macro" if average in ["weighted", "none", None] else average | ||
reduce = "macro" if average in {"weighted", "none", None} else average | ||
|
||
if subset_accuracy and _check_subset_validity(mode): | ||
correct, total = _subset_accuracy_update(preds, target, threshold, top_k, ignore_index) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,9 +151,10 @@ def _average_precision_compute_with_precision_recall( | |
if num_classes == 1: | ||
return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) | ||
|
||
res = [] | ||
for p, r in zip(precision, recall): | ||
res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) | ||
res = [ | ||
-torch.sum((r[1:] - r[:-1]) * p[:-1]) | ||
for p, r in zip(precision, recall) | ||
] | ||
Comment on lines
-154
to
+157
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
# Reduce | ||
if average in ("macro", "weighted"): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,11 +143,11 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: | |
_, _, mode = _input_format_classification(preds, target) | ||
|
||
if mode == DataType.BINARY: | ||
if not ((0 <= preds) * (preds <= 1)).all(): | ||
if not ((preds >= 0) * (preds <= 1)).all(): | ||
preds = preds.sigmoid() | ||
confidences, accuracies = preds, target | ||
elif mode == DataType.MULTICLASS: | ||
if not ((0 <= preds) * (preds <= 1)).all(): | ||
if not ((preds >= 0) * (preds <= 1)).all(): | ||
Comment on lines
-146
to
+150
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
preds = preds.softmax(dim=1) | ||
confidences, predictions = preds.max(dim=1) | ||
accuracies = predictions.eq(target) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,7 +41,7 @@ def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tens | |
""" | ||
|
||
confmat = _confusion_matrix_compute(confmat) | ||
confmat = confmat.float() if not confmat.is_floating_point() else confmat | ||
confmat = confmat if confmat.is_floating_point() else confmat.float() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
n_classes = confmat.shape[0] | ||
sum0 = confmat.sum(dim=0, keepdim=True) | ||
sum1 = confmat.sum(dim=1, keepdim=True) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,11 +47,11 @@ def _confusion_matrix_update( | |
minlength = num_classes**2 | ||
|
||
bins = _bincount(unique_mapping, minlength=minlength) | ||
if multilabel: | ||
confmat = bins.reshape(num_classes, 2, 2) | ||
else: | ||
confmat = bins.reshape(num_classes, num_classes) | ||
return confmat | ||
return ( | ||
bins.reshape(num_classes, 2, 2) | ||
if multilabel | ||
else bins.reshape(num_classes, num_classes) | ||
) | ||
Comment on lines
-50
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: | ||
|
@@ -98,7 +98,7 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) | |
if normalize not in allowed_normalize: | ||
raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}") | ||
if normalize is not None and normalize != "none": | ||
confmat = confmat.float() if not confmat.is_floating_point() else confmat | ||
confmat = confmat if confmat.is_floating_point() else confmat.float() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
if normalize == "true": | ||
confmat = confmat / confmat.sum(axis=1, keepdim=True) | ||
elif normalize == "pred": | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,10 +90,7 @@ def dice_score( | |
if zero_division != nan_score: | ||
rank_zero_warn(f"Deprecated parameter. `nan_score` converted to integer {zero_division}.") | ||
|
||
ignore_index = None | ||
if not bg: | ||
ignore_index = 0 | ||
|
||
ignore_index = None if bg else 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
return dice( | ||
preds, | ||
target, | ||
|
@@ -272,7 +269,9 @@ def dice( | |
if average not in allowed_average: | ||
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") | ||
|
||
if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): | ||
if average in {"macro", "weighted", "none", None} and ( | ||
(not num_classes or num_classes < 1) | ||
): | ||
Comment on lines
-275
to
+274
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") | ||
|
||
allowed_mdmc_average = [None, "samplewise", "global"] | ||
|
@@ -286,7 +285,7 @@ def dice( | |
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") | ||
|
||
preds, target = _input_squeeze(preds, target) | ||
reduce = "macro" if average in ("weighted", "none", None) else average | ||
reduce = "macro" if average in {"weighted", "none", None} else average | ||
|
||
tp, fp, _, fn = _stat_scores_update( | ||
preds, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines
178-221
refactored with the following changes:use-fstring-for-concatenation
)