-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from mmcint/main
Adding SectionSeeker Code, ReadME, and Quick start notebook
- Loading branch information
Showing
19 changed files
with
1,836 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Welcome to Section Seeker! | ||
|
||
![Section-Seeker-Logo](assets/reflect_connect.png) | ||
|
||
#### **This notebook and package has been adapted from ThinkOnward's [Reflection Connection Challenge](https://thinkonward.com/app/c/challenges/reflection-connection), which ran in late 2023. The SectionSeeker can be used to train a SiameseNN to identify similar sections to the one a user inputs. This can be extremely useful for seismic interpreters looking for an analog section or basin.** | ||
|
||
|
||
#### Background | ||
|
||
Siamese Neural Networks (SNN) have shown great skill at one-shot learning collections of various images. This challenge asks you to train an algorithm to find similar-looking images of seismic data within a larger corpus using a limited training for eight categories. Your solution will need to match as many different features using these data. This challenge is experimental, so we are keen to see how different participants utilize this framework to build a solution. | ||
|
||
To non-geophysicists, seismic images are mysterious: lots of black-and-white squiggly lines stacked on one another. However, with more experience, different features in the seismic can be identified. These features represent common geology structures: a river channel, a salt pan, or a fault. Recognizing seismic features is no different from when a medical technician recognizes the difference between a heart valve or major artery on an echocardiogram. A geoscientist combines all these features into a hypothesis about how the Earth developed in the survey area. An algorithm that can identify parts of a seismic image will enable geoscientists to build more robust hypotheses and spend more time integrating other pieces of information into a comprehensive model of the Earth. | ||
|
||
#### Getting Started | ||
|
||
Check out the starter notebook for help getting your own SNN up and running! |
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
albumentations==1.3.1 | ||
contextlib2==21.6.0 | ||
joblib==1.3.2 | ||
multiprocess==0.70.15 | ||
numba==0.58.1 | ||
numpy==1.26.1 | ||
nvgpu==0.10.0 | ||
nvidia-ml-py==12.535.108 | ||
opencv-python==4.8.1.78 | ||
packaging==21.3 | ||
pandas==1.5.3 | ||
pathos==0.3.1 | ||
Pillow==10.1.0 | ||
py4j==0.10.9.5 | ||
pyarrow==13.0.0 | ||
pyfunctional==1.4.3 | ||
PyYAML | ||
safetensors==0.4.0 | ||
scikit-image==0.22.0 | ||
scikit-learn==1.3.2 | ||
scipy | ||
torch==2.0.1 | ||
torchvision==0.15.2 | ||
tqdm | ||
ujson==5.8.0 | ||
Werkzeug==3.0.1 | ||
matplotlib==3.8.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import json | ||
import numpy as np | ||
|
||
class ResultBuilder: | ||
def __init__(self): | ||
self.results = dict() | ||
|
||
def build(self, | ||
query_image_labels: np.ndarray, | ||
matched_labels: np.ndarray, | ||
confidence_scores: np.ndarray): | ||
""" | ||
Prepare results in expected form | ||
:param query_image_labels: numpy array of N reference image labels with shape [N] | ||
:param matched_labels: numpy array of labels of matched base images. Given N query images, this should have shape (N, 3) | ||
:param confidence_scores: numpy array of confidence scores for each matched based image. Given N query images, this should have shape (N, 3) | ||
""" | ||
|
||
# validate shapes of inputs | ||
if len(query_image_labels.shape) != 1: | ||
raise ValueError(f'Expected query_image_labels to be 1-dimensional array, got {query_image_labels.shape} instead') | ||
|
||
if matched_labels.shape != (query_image_labels.shape[0],3): | ||
raise ValueError(f'Expected matched_labels to have shape {(query_image_labels.shape[0], 3)}, got {matched_labels.shape} instead') | ||
|
||
if confidence_scores.shape != (query_image_labels.shape[0],3): | ||
raise ValueError(f'Expected confidence_scores to have shape {(query_image_labels.shape[0], 3)}, got {confidence_scores.shape} instead') | ||
|
||
for i, x in enumerate(query_image_labels): | ||
labels = matched_labels[i] | ||
confidence = confidence_scores[i] | ||
|
||
result_x = [{'label': labels[j], 'confidence': confidence[j]} for j in range(0,3)] | ||
|
||
self.results.update({x: result_x}) | ||
|
||
return self | ||
|
||
def to_json(self, path: str = '.') -> None: | ||
""" | ||
Save results to json file | ||
:param path: parent directory of result.json file | ||
""" | ||
|
||
path = f'{path}/results.json' | ||
with open(path, 'w+') as f: | ||
json.dump(self.results, f) | ||
|
||
def __call__(self, | ||
query_image_labels: np.ndarray, | ||
matched_labels: np.ndarray, | ||
confidence_scores: np.ndarray, | ||
path: str = '.') -> None: | ||
""" | ||
Build result and save results to json file | ||
""" | ||
self.build(query_image_labels, matched_labels, confidence_scores) | ||
self.to_json(path) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
import codecs | ||
import os | ||
import os.path | ||
import shutil | ||
import string | ||
import sys | ||
import warnings | ||
from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence | ||
|
||
import numpy as np | ||
from scipy.spatial import KDTree | ||
import torch | ||
from PIL import Image | ||
import glob | ||
import cv2 | ||
|
||
from snn.utils import detect_device | ||
from snn.model import SiameseNetwork | ||
|
||
|
||
class ImageSet: | ||
""" | ||
Subscriptapble dataset-like class for loading, storing and processing image collections | ||
:param root: Path to project root directory, which contains data/image_corpus/ or data/query catalog | ||
:param base: Build ImageSet on top of image_corpus if True, else on top of query catalog | ||
:param build: Build ImageSet from filesystem instead of using saved version | ||
:param transform: Callable that will be applied to all images when calling __getitem__() method | ||
:param compatibility_mode: Convert images to PIL.Image before applying transform and returning from __getitime__() method | ||
:param greyscale: Load images in grayscale if True, else use 3-channel RGB | ||
:param normalize: If True, images will be normalized image-wise when loaded from disk | ||
""" | ||
def __init__(self, | ||
root: str, | ||
base: bool = True, | ||
build: bool = False, | ||
transform: Callable = None, | ||
compatibility_mode: bool = False, | ||
greyscale: bool = False, | ||
normalize: bool = True) -> None: | ||
|
||
self.root = root | ||
self.compatibility_mode = compatibility_mode | ||
self.greyscale = greyscale | ||
self.colormode = 'L' if greyscale else 'RGB' | ||
self.transform = transform | ||
self.base = base | ||
self.normalize = normalize | ||
|
||
if build: | ||
self.embeddings = [] | ||
self.data, self.names = self._build() | ||
return | ||
|
||
self.data = self._load() | ||
|
||
|
||
def _build(self) -> Tuple[torch.Tensor, str]: | ||
|
||
dirpath = f"{self.root}/data/{'image_corpus' if self.base else 'query'}" | ||
data = [] | ||
images = [] | ||
names = [] | ||
for filename in glob.glob(f"{dirpath}/*png"): | ||
im = Image.open(filename) | ||
# resize into common shape | ||
im = im.convert(self.colormode).resize((118, 143)) | ||
if self.normalize: | ||
im = cv2.normalize(np.array(im), None, 0.0, 1.0, cv2.NORM_MINMAX, cv2.CV_32FC1) | ||
image = np.array(im, dtype=np.float32) | ||
fname = filename.split('/')[-1] | ||
data.append(image) | ||
names.append(fname) | ||
return torch.from_numpy(np.array(data)), names | ||
|
||
def _load(self) -> Tuple[torch.Tensor, str]: | ||
... | ||
|
||
def save(self) -> None: | ||
... | ||
|
||
def build_embeddings(self, model: SiameseNetwork, device: torch.cuda.device = None): | ||
|
||
if device is None: | ||
device = detect_device() | ||
|
||
with torch.no_grad(): | ||
model.eval() | ||
for img, name in self: | ||
img_input = img.transpose(2,0).transpose(2,1).to(device).unsqueeze(0) | ||
embedding = model.get_embedding(img_input) | ||
self.embeddings.append((embedding, name)) | ||
|
||
return self | ||
|
||
def get_embeddings(self) -> List[Tuple[torch.Tensor, str]]: | ||
if self.embeddings is None: | ||
raise RuntimeError('Embedding collection is empty. Run self.build_embeddings() method to build it') | ||
|
||
return self.embeddings | ||
|
||
def __getitem__(self, index: int) -> Tuple[Any, Any]: | ||
""" | ||
Args: | ||
index (int): Index | ||
Returns: | ||
tuple: (image, target) where target is index of the target class. | ||
""" | ||
img = self.data[index] | ||
name = self.names[index] | ||
# doing this so that it is consistent with all other datasets | ||
# to return a PIL Image | ||
if self.compatibility_mode: | ||
img = Image.fromarray(img.numpy(), mode=self.colormode) | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
return img, name | ||
|
||
|
||
class SearchTree: | ||
""" | ||
Wrapper for k-d tree built on image embeddings | ||
:param query_set: instance of base ImageSet with built embedding representation | ||
""" | ||
def __init__(self, query_set: ImageSet) -> None: | ||
embeddings = query_set.get_embeddings() | ||
self.embeddings = np.concatenate([x[0].cpu().numpy() for x in embeddings], axis=0) | ||
self.names = np.array([x[1] for x in embeddings]) | ||
self.kdtree = self._build_kdtree() | ||
|
||
def _build_kdtree(self) -> KDTree: | ||
print('Building KD-Tree from embeddings') | ||
return KDTree(self.embeddings) | ||
|
||
def query(self, anchors: ImageSet, k: int = 3) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: | ||
""" | ||
Search for k nearest neighbors of provided anchor embeddings | ||
:param anchors: instance of query (reference) ImageSet with built embedding representation | ||
:returns: tuple of reference_labels, distances to matched label embeddings, matched label embeddings, matched_labels | ||
""" | ||
|
||
reference = anchors.get_embeddings() | ||
reference_embeddings = np.concatenate([x[0].cpu().numpy() for x in reference], axis=0) | ||
reference_labels = np.array([x[1] for x in reference]) | ||
|
||
distances, indices = self.kdtree.query(reference_embeddings, k=k, workers=-1) | ||
return reference_labels, distances, self.embeddings[indices], self.names[indices] | ||
|
||
def __call__(self, *args, **kwargs) -> Any: | ||
return self.query(*args, **kwargs) | ||
|
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from dataclasses import dataclass | ||
|
||
@dataclass | ||
class ModelConfig: | ||
BACKBONE_MODEL: str = 'ResNet50' | ||
BACKBONE_MODEL_WEIGHTS: str = 'ResNet50_Weights.IMAGENET1K_V2' | ||
LATENT_SPACE_DIM: int = 8 | ||
FC_IN_FEATURES: int = -1 | ||
|
||
|
||
defaultConfig = ModelConfig() | ||
|
||
vitBaseConfig = ModelConfig(BACKBONE_MODEL = 'ViT_B_16', | ||
BACKBONE_MODEL_WEIGHTS = 'ViT_B_16_Weights.DEFAULT', | ||
LATENT_SPACE_DIM = 16, | ||
FC_IN_FEATURES = 768) | ||
|
||
vitBaseConfigPretrained = ModelConfig(BACKBONE_MODEL = 'ViT_B_16', | ||
BACKBONE_MODEL_WEIGHTS = '../checkpoints/ViT_B_16_SEISMIC_SGD_28G_M75.pth', | ||
LATENT_SPACE_DIM = 16, | ||
FC_IN_FEATURES = 768) |
Empty file.
Oops, something went wrong.