-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from jessepisel/dev
HuggingFace updates
- Loading branch information
Showing
20 changed files
with
11,084 additions
and
694 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,73 @@ | ||
import json | ||
import numpy as np | ||
|
||
|
||
class ResultBuilder: | ||
def __init__(self): | ||
self.results = dict() | ||
|
||
def build(self, | ||
query_image_labels: np.ndarray, | ||
matched_labels: np.ndarray, | ||
confidence_scores: np.ndarray): | ||
|
||
def build( | ||
self, | ||
query_image_labels: np.ndarray, | ||
matched_labels: np.ndarray, | ||
confidence_scores: np.ndarray, | ||
): | ||
""" | ||
Prepare results in expected form | ||
:param query_image_labels: numpy array of N reference image labels with shape [N] | ||
:param matched_labels: numpy array of labels of matched base images. Given N query images, this should have shape (N, 3) | ||
:param confidence_scores: numpy array of confidence scores for each matched based image. Given N query images, this should have shape (N, 3) | ||
""" | ||
|
||
# validate shapes of inputs | ||
if len(query_image_labels.shape) != 1: | ||
raise ValueError(f'Expected query_image_labels to be 1-dimensional array, got {query_image_labels.shape} instead') | ||
|
||
if matched_labels.shape != (query_image_labels.shape[0],3): | ||
raise ValueError(f'Expected matched_labels to have shape {(query_image_labels.shape[0], 3)}, got {matched_labels.shape} instead') | ||
|
||
if confidence_scores.shape != (query_image_labels.shape[0],3): | ||
raise ValueError(f'Expected confidence_scores to have shape {(query_image_labels.shape[0], 3)}, got {confidence_scores.shape} instead') | ||
|
||
raise ValueError( | ||
f"Expected query_image_labels to be 1-dimensional array, got {query_image_labels.shape} instead" | ||
) | ||
|
||
if matched_labels.shape != (query_image_labels.shape[0], 3): | ||
raise ValueError( | ||
f"Expected matched_labels to have shape {(query_image_labels.shape[0], 3)}, got {matched_labels.shape} instead" | ||
) | ||
|
||
if confidence_scores.shape != (query_image_labels.shape[0], 3): | ||
raise ValueError( | ||
f"Expected confidence_scores to have shape {(query_image_labels.shape[0], 3)}, got {confidence_scores.shape} instead" | ||
) | ||
|
||
for i, x in enumerate(query_image_labels): | ||
labels = matched_labels[i] | ||
confidence = confidence_scores[i] | ||
|
||
result_x = [{'label': labels[j], 'confidence': confidence[j]} for j in range(0,3)] | ||
|
||
|
||
result_x = [ | ||
{"label": labels[j], "confidence": confidence[j]} for j in range(0, 3) | ||
] | ||
|
||
self.results.update({x: result_x}) | ||
|
||
return self | ||
def to_json(self, path: str = '.') -> None: | ||
|
||
def to_json(self, path: str = ".") -> None: | ||
""" | ||
Save results to json file | ||
Save results to json file | ||
:param path: parent directory of result.json file | ||
""" | ||
path = f'{path}/results.json' | ||
with open(path, 'w+') as f: | ||
|
||
path = f"{path}/results.json" | ||
with open(path, "w+") as f: | ||
json.dump(self.results, f) | ||
|
||
def __call__(self, | ||
query_image_labels: np.ndarray, | ||
matched_labels: np.ndarray, | ||
confidence_scores: np.ndarray, | ||
path: str = '.') -> None: | ||
|
||
def __call__( | ||
self, | ||
query_image_labels: np.ndarray, | ||
matched_labels: np.ndarray, | ||
confidence_scores: np.ndarray, | ||
path: str = ".", | ||
) -> None: | ||
""" | ||
Build result and save results to json file | ||
""" | ||
self.build(query_image_labels, matched_labels, confidence_scores) | ||
self.to_json(path) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class ModelConfig: | ||
BACKBONE_MODEL: str = "ResNet50" | ||
BACKBONE_MODEL_WEIGHTS: str = "ResNet50_Weights.IMAGENET1K_V2" | ||
LATENT_SPACE_DIM: int = 8 | ||
FC_IN_FEATURES: int = -1 | ||
|
||
|
||
defaultConfig = ModelConfig() | ||
|
||
vitBaseConfig = ModelConfig( | ||
BACKBONE_MODEL="ViT_B_16", | ||
BACKBONE_MODEL_WEIGHTS="ViT_B_16_Weights.DEFAULT", | ||
LATENT_SPACE_DIM=16, | ||
FC_IN_FEATURES=768, | ||
) | ||
|
||
vitBaseConfigPretrained = ModelConfig( | ||
BACKBONE_MODEL="ViT_B_16", | ||
BACKBONE_MODEL_WEIGHTS="../checkpoints/ViT_B_16_SEISMIC_SGD_28G_M75.pth", | ||
LATENT_SPACE_DIM=16, | ||
FC_IN_FEATURES=768, | ||
) |
File renamed without changes.
Oops, something went wrong.