diff --git a/wbia/algo/detect/efficientnet.py b/wbia/algo/detect/efficientnet.py index 57ab5ceb9..42d2c476d 100644 --- a/wbia/algo/detect/efficientnet.py +++ b/wbia/algo/detect/efficientnet.py @@ -141,26 +141,26 @@ def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False 'left': 3, 'right': 3, } - + self.reverse_label_map = { - 'up':1, - 'down':2, - 'front':3, - 'back':4, - 'left':5, - 'right':6, - 'upfront':7, - 'upback':8, - 'upleft':9, - 'upright':10, - 'downfront':11, - 'downback':12, - 'downleft':13, - 'downright':14, - 'frontleft':15, - 'frontright':16, - 'backleft':17, - 'backright':18, + 'up': 1, + 'down': 2, + 'front': 3, + 'back': 4, + 'left': 5, + 'right': 6, + 'upfront': 7, + 'upback': 8, + 'upleft': 9, + 'upright': 10, + 'downfront': 11, + 'downback': 12, + 'downleft': 13, + 'downright': 14, + 'frontleft': 15, + 'frontright': 16, + 'backleft': 17, + 'backright': 18, } if multilabel: @@ -181,8 +181,8 @@ def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False ''' def process_row(self, row_labels, preds, sort_weights, labels): - multi_labels = labels[row_labels == True] - preds = preds[row_labels == True] + multi_labels = labels[row_labels.astype(bool)] + preds = preds[row_labels.astype(bool)] # Combine the multi_labels and preds into a list of tuples label_pred_weight = [(label, pred, sort_weights[label]) for label, pred in zip(multi_labels, preds)] # Sort by weights first, then by prediction values in descending order @@ -195,7 +195,7 @@ def process_row(self, row_labels, preds, sort_weights, labels): # Extract the labels in the order of weights - top 2 labels sorted_labels = [best_labels[weight][0] for weight in sorted(best_labels)[:2]] return sorted_labels - + def process_multilabel_preds(self, image_preds, sort_weights, labels, reverse_label_map): multi_label_matrix = (image_preds > 0.5).cpu().numpy() sorted_labels = [self.process_row(row, preds, sort_weights, labels) for row, preds in zip(multi_label_matrix, image_preds)] @@ -211,7 +211,7 @@ def process_multilabel_preds(self, image_preds, sort_weights, labels, reverse_la one_hot_tensor = torch.tensor(one_hot_matrix, dtype=torch.float32) return one_hot_tensor - + def forward(self, x): x = self.model(x) if self.multilabel: @@ -626,7 +626,7 @@ def test_single(filepath_list, weights_path, batch_size=1792, multi=PARALLEL, ** num_classes = len(classes) # Initialize the model for this run - multilabel = 'multilabel' in weights_path + multilabel = 'multilabel' in weights_path model = EfficientnetModel(n_class=num_classes, multilabel=multilabel) # num_ftrs = model.classifier.in_features # model.classifier = nn.Linear(num_ftrs, num_classes)