Skip to content

Commit

Permalink
fixes linting
Browse files Browse the repository at this point in the history
  • Loading branch information
LashaO committed Jun 13, 2024
1 parent 6a1ad02 commit 5b2322a
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions wbia/algo/detect/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,26 @@ def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False
'left': 3,
'right': 3,
}

self.reverse_label_map = {
'up':1,
'down':2,
'front':3,
'back':4,
'left':5,
'right':6,
'upfront':7,
'upback':8,
'upleft':9,
'upright':10,
'downfront':11,
'downback':12,
'downleft':13,
'downright':14,
'frontleft':15,
'frontright':16,
'backleft':17,
'backright':18,
'up': 1,
'down': 2,
'front': 3,
'back': 4,
'left': 5,
'right': 6,
'upfront': 7,
'upback': 8,
'upleft': 9,
'upright': 10,
'downfront': 11,
'downback': 12,
'downleft': 13,
'downright': 14,
'frontleft': 15,
'frontright': 16,
'backleft': 17,
'backright': 18,
}

if multilabel:
Expand All @@ -181,8 +181,8 @@ def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False
'''

def process_row(self, row_labels, preds, sort_weights, labels):
multi_labels = labels[row_labels == True]
preds = preds[row_labels == True]
multi_labels = labels[row_labels.astype(bool)]
preds = preds[row_labels.astype(bool)]
# Combine the multi_labels and preds into a list of tuples
label_pred_weight = [(label, pred, sort_weights[label]) for label, pred in zip(multi_labels, preds)]
# Sort by weights first, then by prediction values in descending order
Expand All @@ -195,7 +195,7 @@ def process_row(self, row_labels, preds, sort_weights, labels):
# Extract the labels in the order of weights - top 2 labels
sorted_labels = [best_labels[weight][0] for weight in sorted(best_labels)[:2]]
return sorted_labels

def process_multilabel_preds(self, image_preds, sort_weights, labels, reverse_label_map):
multi_label_matrix = (image_preds > 0.5).cpu().numpy()
sorted_labels = [self.process_row(row, preds, sort_weights, labels) for row, preds in zip(multi_label_matrix, image_preds)]
Expand All @@ -211,7 +211,7 @@ def process_multilabel_preds(self, image_preds, sort_weights, labels, reverse_la
one_hot_tensor = torch.tensor(one_hot_matrix, dtype=torch.float32)

return one_hot_tensor

def forward(self, x):
x = self.model(x)
if self.multilabel:
Expand Down Expand Up @@ -626,7 +626,7 @@ def test_single(filepath_list, weights_path, batch_size=1792, multi=PARALLEL, **
num_classes = len(classes)

# Initialize the model for this run
multilabel = 'multilabel' in weights_path
multilabel = 'multilabel' in weights_path
model = EfficientnetModel(n_class=num_classes, multilabel=multilabel)
# num_ftrs = model.classifier.in_features
# model.classifier = nn.Linear(num_ftrs, num_classes)
Expand Down

0 comments on commit 5b2322a

Please sign in to comment.