Merge pull request #1 from mmcint/main
Adding SectionSeeker Code, ReadME, and Quick start notebook
jessepisel authored Dec 4, 2024
2 parents b63d51d + 54bf100 commit 8658e4b
Showing 19 changed files with 1,836 additions and 0 deletions.
16 changes: 16 additions & 0 deletions README.md
@@ -0,0 +1,16 @@
# Welcome to Section Seeker!

![Section-Seeker-Logo](assets/reflect_connect.png)

#### **This notebook and package have been adapted from ThinkOnward's [Reflection Connection Challenge](https://thinkonward.com/app/c/challenges/reflection-connection), which ran in late 2023. The SectionSeeker can be used to train a Siamese neural network (SNN) to identify sections similar to the one a user inputs. This can be extremely useful for seismic interpreters looking for an analog section or basin.**


#### Background

Siamese Neural Networks (SNNs) have shown great skill at one-shot learning on varied image collections. This challenge asks you to train an algorithm to find similar-looking images of seismic data within a larger corpus, using a limited training set covering eight categories. Your solution will need to match as many different features as possible using these data. This challenge is experimental, so we are keen to see how different participants utilize this framework to build a solution.
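
As a rough illustration of how an SNN learns from pairs of images, the sketch below computes a contrastive loss on embedding pairs with NumPy. This is an assumption for illustration only: the README does not say which loss SectionSeeker uses, and `contrastive_loss` and `margin` are hypothetical names, not part of the package.

```python
import numpy as np


def contrastive_loss(emb_a, emb_b, same_class, margin=1.0):
    """Contrastive loss over a batch of embedding pairs (illustrative only)."""
    # Euclidean distance between the two embeddings of each pair, shape (N,)
    dist = np.linalg.norm(emb_a - emb_b, axis=1)
    # Pairs from the same category are pulled together
    pos = same_class * dist ** 2
    # Pairs from different categories are pushed apart until they exceed the margin
    neg = (1.0 - same_class) * np.maximum(margin - dist, 0.0) ** 2
    return float(np.mean(pos + neg))


# Two toy pairs: the first is identical, the second is 5 units apart
a = np.array([[0.0, 0.0], [0.0, 0.0]])
b = np.array([[0.0, 0.0], [3.0, 4.0]])

well_separated = contrastive_loss(a, b, np.array([1.0, 0.0]))
mislabeled = contrastive_loss(a, b, np.array([1.0, 1.0]))
```

With the first pair marked as similar and the second as dissimilar, nothing is penalized; marking the distant pair as similar produces a large loss, which is the signal that pulls matching sections together in embedding space.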

To non-geophysicists, seismic images are mysterious: lots of black-and-white squiggly lines stacked on one another. However, with more experience, different features in the seismic can be identified. These features represent common geological structures: a river channel, a salt pan, or a fault. Recognizing seismic features is no different from a medical technician recognizing the difference between a heart valve and a major artery on an echocardiogram. A geoscientist combines all these features into a hypothesis about how the Earth developed in the survey area. An algorithm that can identify parts of a seismic image will enable geoscientists to build more robust hypotheses and spend more time integrating other pieces of information into a comprehensive model of the Earth.

#### Getting Started

Check out the starter notebook for help getting your own SNN up and running!
704 changes: 704 additions & 0 deletions SectionSeeker_Quickstart_notebook.ipynb

Large diffs are not rendered by default.

Binary file added assets/reflect_connect.png
27 changes: 27 additions & 0 deletions requirements.txt
@@ -0,0 +1,27 @@
albumentations==1.3.1
contextlib2==21.6.0
joblib==1.3.2
multiprocess==0.70.15
numba==0.58.1
numpy==1.26.1
nvgpu==0.10.0
nvidia-ml-py==12.535.108
opencv-python==4.8.1.78
packaging==21.3
pandas==1.5.3
pathos==0.3.1
Pillow==10.1.0
py4j==0.10.9.5
pyarrow==13.0.0
pyfunctional==1.4.3
PyYAML
safetensors==0.4.0
scikit-image==0.22.0
scikit-learn==1.3.2
scipy
torch==2.0.1
torchvision==0.15.2
tqdm
ujson==5.8.0
Werkzeug==3.0.1
matplotlib==3.8.0
61 changes: 61 additions & 0 deletions src/results.py
@@ -0,0 +1,61 @@
import json
import numpy as np


class ResultBuilder:
    def __init__(self):
        self.results = dict()

    def build(self,
              query_image_labels: np.ndarray,
              matched_labels: np.ndarray,
              confidence_scores: np.ndarray):
        """
        Prepare results in expected form
        :param query_image_labels: numpy array of N reference image labels with shape (N,)
        :param matched_labels: numpy array of labels of matched base images. Given N query images, this should have shape (N, 3)
        :param confidence_scores: numpy array of confidence scores for each matched base image. Given N query images, this should have shape (N, 3)
        """

        # validate shapes of inputs
        if len(query_image_labels.shape) != 1:
            raise ValueError(f'Expected query_image_labels to be 1-dimensional array, got {query_image_labels.shape} instead')

        if matched_labels.shape != (query_image_labels.shape[0], 3):
            raise ValueError(f'Expected matched_labels to have shape {(query_image_labels.shape[0], 3)}, got {matched_labels.shape} instead')

        if confidence_scores.shape != (query_image_labels.shape[0], 3):
            raise ValueError(f'Expected confidence_scores to have shape {(query_image_labels.shape[0], 3)}, got {confidence_scores.shape} instead')

        for i, x in enumerate(query_image_labels):
            labels = matched_labels[i]
            confidence = confidence_scores[i]

            # cast NumPy scalars to built-in types so json.dump can serialize them
            # (np.float32, for example, is not JSON serializable)
            result_x = [{'label': str(labels[j]), 'confidence': float(confidence[j])} for j in range(3)]

            self.results.update({str(x): result_x})

        return self

    def to_json(self, path: str = '.') -> None:
        """
        Save results to json file
        :param path: parent directory of results.json file
        """

        path = f'{path}/results.json'
        with open(path, 'w+') as f:
            json.dump(self.results, f)

    def __call__(self,
                 query_image_labels: np.ndarray,
                 matched_labels: np.ndarray,
                 confidence_scores: np.ndarray,
                 path: str = '.') -> None:
        """
        Build result and save results to json file
        """
        self.build(query_image_labels, matched_labels, confidence_scores)
        self.to_json(path)
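
To make the output format concrete, here is a minimal standalone sketch of the structure that `ResultBuilder.build` assembles: one entry per query image, each holding its three matches. The helper `build_results` and the file names are hypothetical and exist only for this illustration; the NumPy values are cast to built-in types here so that `json` can serialize them.

```python
import json

import numpy as np


def build_results(query_labels, matched_labels, confidence_scores):
    """Map each query label to its matches and confidences (illustrative helper)."""
    results = {}
    for i, query in enumerate(query_labels):
        results[str(query)] = [
            {
                "label": str(matched_labels[i, j]),
                # float() converts NumPy scalars such as np.float32 for json
                "confidence": float(confidence_scores[i, j]),
            }
            for j in range(matched_labels.shape[1])
        ]
    return results


# Hypothetical inputs: 2 query images, 3 matches each
query_labels = np.array(["q0.png", "q1.png"])
matched_labels = np.array([["a.png", "b.png", "c.png"],
                           ["d.png", "e.png", "f.png"]])
confidence_scores = np.array([[0.9, 0.7, 0.5],
                              [0.8, 0.6, 0.4]], dtype=np.float32)

results = build_results(query_labels, matched_labels, confidence_scores)
payload = json.dumps(results, indent=2)
```

Each key of the resulting JSON object is a query image label, and each value is a list of three `{"label", "confidence"}` records in the order the matches were supplied.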

157 changes: 157 additions & 0 deletions src/search.py
@@ -0,0 +1,157 @@
import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence

import numpy as np
from scipy.spatial import KDTree
import torch
from PIL import Image
import glob
import cv2

from snn.utils import detect_device
from snn.model import SiameseNetwork


class ImageSet:
    """
    Subscriptable dataset-like class for loading, storing and processing image collections
    :param root: Path to project root directory, which contains data/image_corpus/ or data/query catalog
    :param base: Build ImageSet on top of image_corpus if True, else on top of query catalog
    :param build: Build ImageSet from filesystem instead of using saved version
    :param transform: Callable that will be applied to all images when calling __getitem__() method
    :param compatibility_mode: Convert images to PIL.Image before applying transform and returning from __getitem__() method
    :param greyscale: Load images in greyscale if True, else use 3-channel RGB
    :param normalize: If True, images will be normalized image-wise when loaded from disk
    """
    def __init__(self,
                 root: str,
                 base: bool = True,
                 build: bool = False,
                 transform: Optional[Callable] = None,
                 compatibility_mode: bool = False,
                 greyscale: bool = False,
                 normalize: bool = True) -> None:

        self.root = root
        self.compatibility_mode = compatibility_mode
        self.greyscale = greyscale
        self.colormode = 'L' if greyscale else 'RGB'
        self.transform = transform
        self.base = base
        self.normalize = normalize
        # always defined, so get_embeddings() can report an empty collection
        self.embeddings = []

        if build:
            self.data, self.names = self._build()
            return

        self.data = self._load()

    def _build(self) -> Tuple[torch.Tensor, List[str]]:

        dirpath = f"{self.root}/data/{'image_corpus' if self.base else 'query'}"
        data = []
        names = []
        for filename in glob.glob(f"{dirpath}/*png"):
            im = Image.open(filename)
            # resize into common shape; PIL takes (width, height)
            im = im.convert(self.colormode).resize((118, 143))
            if self.normalize:
                # min-max scale each image to the range [0, 1]
                im = cv2.normalize(np.array(im), None, 0.0, 1.0, cv2.NORM_MINMAX, cv2.CV_32FC1)
            image = np.array(im, dtype=np.float32)
            fname = os.path.basename(filename)
            data.append(image)
            names.append(fname)
        return torch.from_numpy(np.array(data)), names

    def _load(self) -> Tuple[torch.Tensor, str]:
        ...

    def save(self) -> None:
        ...

    def build_embeddings(self, model: SiameseNetwork, device: torch.cuda.device = None):

        if device is None:
            device = detect_device()

        with torch.no_grad():
            model.eval()
            for img, name in self:
                # (H, W, C) -> (C, H, W), then add a leading batch dimension
                img_input = img.transpose(2, 0).transpose(2, 1).to(device).unsqueeze(0)
                embedding = model.get_embedding(img_input)
                self.embeddings.append((embedding, name))

        return self

    def get_embeddings(self) -> List[Tuple[torch.Tensor, str]]:
        # an empty list means build_embeddings() has not been run yet
        if not self.embeddings:
            raise RuntimeError('Embedding collection is empty. Run self.build_embeddings() method to build it')

        return self.embeddings

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, name) where name is the file name of the image.
        """
        img = self.data[index]
        name = self.names[index]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        if self.compatibility_mode:
            img = Image.fromarray(img.numpy(), mode=self.colormode)

        if self.transform is not None:
            img = self.transform(img)

        return img, name
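
The two array manipulations in `ImageSet` (per-image min-max scaling and the channel reordering before the model call) can be sketched with NumPy alone. This is an illustration under assumptions: `minmax_normalize` is a hypothetical helper intended to approximate the effect of the `cv2.normalize(..., cv2.NORM_MINMAX, ...)` call, and the random array stands in for a loaded RGB image.

```python
import numpy as np


def minmax_normalize(image):
    """Scale one image to [0, 1] using its own minimum and maximum."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:
        # constant image: avoid division by zero (a choice made for this sketch)
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


# Stand-in for an RGB image resized with PIL's resize((118, 143)):
# width 118 and height 143 give an array of shape (143, 118, 3)
rng = np.random.default_rng(0)
image_hwc = rng.integers(0, 256, size=(143, 118, 3), dtype=np.uint8)

normalized = minmax_normalize(image_hwc)

# (H, W, C) -> (C, H, W), then add a batch dimension -> (1, C, H, W)
batch_chw = np.transpose(normalized, (2, 0, 1))[np.newaxis, ...]
```

The final shape `(1, 3, 143, 118)` is the batch-first, channel-first layout that the chained `transpose(2, 0).transpose(2, 1)` and `unsqueeze(0)` calls produce on the tensor side.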


class SearchTree:
    """
    Wrapper for k-d tree built on image embeddings
    :param query_set: instance of base ImageSet with built embedding representation
    """
    def __init__(self, query_set: ImageSet) -> None:
        embeddings = query_set.get_embeddings()
        self.embeddings = np.concatenate([x[0].cpu().numpy() for x in embeddings], axis=0)
        self.names = np.array([x[1] for x in embeddings])
        self.kdtree = self._build_kdtree()

    def _build_kdtree(self) -> KDTree:
        print('Building KD-Tree from embeddings')
        return KDTree(self.embeddings)

    def query(self, anchors: ImageSet, k: int = 3) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Search for k nearest neighbors of provided anchor embeddings
        :param anchors: instance of query (reference) ImageSet with built embedding representation
        :returns: tuple of reference_labels, distances to matched label embeddings, matched label embeddings, matched_labels
        """

        reference = anchors.get_embeddings()
        reference_embeddings = np.concatenate([x[0].cpu().numpy() for x in reference], axis=0)
        reference_labels = np.array([x[1] for x in reference])

        distances, indices = self.kdtree.query(reference_embeddings, k=k, workers=-1)
        return reference_labels, distances, self.embeddings[indices], self.names[indices]

    def __call__(self, *args, **kwargs) -> Any:
        return self.query(*args, **kwargs)
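
The nearest-neighbour lookup at the heart of `SearchTree.query` can be tried without a trained model. The sketch below uses random vectors as stand-in embeddings (an assumption: real embeddings would come from `ImageSet.build_embeddings`) and hypothetical file names, with the same `k=3` default; `workers=-1` is omitted here for simplicity.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(42)

# Stand-in corpus: 50 embeddings of dimension 8 with hypothetical file names
corpus_embeddings = rng.normal(size=(50, 8))
corpus_names = np.array([f"corpus_{i:03d}.png" for i in range(50)])

# Stand-in queries: slightly perturbed copies of the first four corpus items
query_embeddings = corpus_embeddings[:4] + 1e-3

tree = KDTree(corpus_embeddings)

# distances and indices both have shape (n_queries, k), nearest match first
distances, indices = tree.query(query_embeddings, k=3)

# Indexing with the (n_queries, k) index array returns names of the same shape
matched_names = corpus_names[indices]
```

Because `indices` has shape `(N, 3)`, indexing the name and embedding arrays with it yields `(N, 3)` and `(N, 3, D)` arrays, which is why `query` can return matched labels in the shape `ResultBuilder` expects.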

Empty file added src/snn/__init__.py
Empty file.
Binary file added src/snn/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
21 changes: 21 additions & 0 deletions src/snn/config.py
@@ -0,0 +1,21 @@
from dataclasses import dataclass


@dataclass
class ModelConfig:
    BACKBONE_MODEL: str = 'ResNet50'
    BACKBONE_MODEL_WEIGHTS: str = 'ResNet50_Weights.IMAGENET1K_V2'
    LATENT_SPACE_DIM: int = 8
    FC_IN_FEATURES: int = -1


defaultConfig = ModelConfig()

vitBaseConfig = ModelConfig(BACKBONE_MODEL='ViT_B_16',
                            BACKBONE_MODEL_WEIGHTS='ViT_B_16_Weights.DEFAULT',
                            LATENT_SPACE_DIM=16,
                            FC_IN_FEATURES=768)

vitBaseConfigPretrained = ModelConfig(BACKBONE_MODEL='ViT_B_16',
                                      BACKBONE_MODEL_WEIGHTS='../checkpoints/ViT_B_16_SEISMIC_SGD_28G_M75.pth',
                                      LATENT_SPACE_DIM=16,
                                      FC_IN_FEATURES=768)
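
Since `ModelConfig` is a plain dataclass, a variant configuration can be derived from an existing one with the standard library's `dataclasses.replace`. The sketch below redefines the dataclass so it runs on its own; `wide_config` and the value 32 are hypothetical, and whether the model accepts that latent dimension is not confirmed by this diff.

```python
from dataclasses import dataclass, replace


@dataclass
class ModelConfig:
    # Fields copied from src/snn/config.py so the sketch is self-contained
    BACKBONE_MODEL: str = 'ResNet50'
    BACKBONE_MODEL_WEIGHTS: str = 'ResNet50_Weights.IMAGENET1K_V2'
    LATENT_SPACE_DIM: int = 8
    FC_IN_FEATURES: int = -1


default_config = ModelConfig()

# replace() returns a new instance and leaves the original untouched
wide_config = replace(default_config, LATENT_SPACE_DIM=32)
```

Deriving configs this way keeps the unchanged fields in sync with the defaults instead of repeating them in every variant.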
Empty file added src/snn/dataset/__init__.py
Empty file.