Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shubh/refactor #5

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ docs/ION*
.pytest_cache
*.html

*.ipynb
src/simple_problems
notebooks/.ipynb_checkpoints/*
*.out
.python-version
784 changes: 784 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/0.csv

Large diffs are not rendered by default.

779 changes: 779 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/1.csv

Large diffs are not rendered by default.

880 changes: 880 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/10.csv

Large diffs are not rendered by default.

870 changes: 870 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/11.csv

Large diffs are not rendered by default.

818 changes: 818 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/12.csv

Large diffs are not rendered by default.

818 changes: 818 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/13.csv

Large diffs are not rendered by default.

809 changes: 809 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/14.csv

Large diffs are not rendered by default.

833 changes: 833 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/15.csv

Large diffs are not rendered by default.

841 changes: 841 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/16.csv

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/17.csv

Large diffs are not rendered by default.

714 changes: 714 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/2.csv

Large diffs are not rendered by default.

764 changes: 764 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/3.csv

Large diffs are not rendered by default.

762 changes: 762 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/4.csv

Large diffs are not rendered by default.

731 changes: 731 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/5.csv

Large diffs are not rendered by default.

764 changes: 764 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/6.csv

Large diffs are not rendered by default.

835 changes: 835 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/7.csv

Large diffs are not rendered by default.

885 changes: 885 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/8.csv

Large diffs are not rendered by default.

891 changes: 891 additions & 0 deletions data/android_train_processed/2020-05-14-US-MTV-1/Pixel4/9.csv

Large diffs are not rendered by default.

Binary file modified data/baselines/2020-05-14-US-MTV-1/Pixel4_goGPS.mat
Binary file not shown.
56 changes: 56 additions & 0 deletions src/correction_network/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn

"""
Set Attention Block
(elements, batch, dim_in) -> (elements, batch, dim_out) [Flip elements and batch if batch_first = True]
"""
class SAB(nn.Module):
    """Set Attention Block: full self-attention over the elements of a set.

    Maps (elements, batch, dim_in) -> (elements, batch, dim_out); the
    element and batch axes are swapped when batch_first=True.  Elements
    flagged True in pad_mask are excluded from attention as keys.
    """

    def __init__(self, dim_in, dim_out, num_heads, batch_first=False):
        super().__init__()
        # Shared attention operates at dim_out; separate linear maps lift
        # the input to query/key/value space first.
        self.mab = nn.MultiheadAttention(dim_out, num_heads, batch_first=batch_first)
        self.fc_q = nn.Linear(dim_in, dim_out)
        self.fc_k = nn.Linear(dim_in, dim_out)
        self.fc_v = nn.Linear(dim_in, dim_out)

    def forward(self, X, pad_mask=None):
        query = self.fc_q(X)
        key = self.fc_k(X)
        value = self.fc_v(X)
        # Attention weights are discarded; only the attended features matter.
        attended, _ = self.mab(query, key, value, key_padding_mask=pad_mask)
        return attended

"""
Induced Set Attention Block, more efficient than set-attention
(time, batch, dim_in) -> (time, batch, dim_out) [Flip time and batch if batch_first = True]
"""
class ISAB(nn.Module):
    """Induced Set Attention Block — cheaper than full set self-attention.

    Attends a small set of learned inducing points onto X to build a
    summary H, then attends X onto H:
        H   = MHA(I, X, X)
        out = MHA(X, H, H)

    (time, batch, dim_in) -> (time, batch, dim_in); the time and batch axes
    are swapped when batch_first=True.  NOTE(review): the output width is
    dim_in (mab1's embed_dim), not dim_out — dim_out is only the width of
    the induced summary H; the surrounding module docstring overstates this.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, batch_first=False):
        super().__init__()
        # Learned inducing points, with a singleton axis where the batch goes
        # so they can be tiled per batch (the original was batch-first-only).
        if batch_first:
            self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        else:
            self.I = nn.Parameter(torch.Tensor(num_inds, 1, dim_out))
        nn.init.xavier_uniform_(self.I)
        # mab0: queries are the inducing points (dim_out); keys AND values are
        # X, so both kdim and vdim must be dim_in.  (The original declared
        # vdim=dim_out, which fails whenever dim_in != dim_out.)
        self.mab0 = nn.MultiheadAttention(dim_out, num_heads, kdim=dim_in, vdim=dim_in, batch_first=batch_first)
        # mab1: queries are X (dim_in); keys/values are the summary H (dim_out).
        self.mab1 = nn.MultiheadAttention(dim_in, num_heads, kdim=dim_out, vdim=dim_out, batch_first=batch_first)
        self.batch_first = batch_first

    def forward(self, X, pad_mask=None):
        # Tile the inducing points along whichever axis holds the batch.
        if self.batch_first:
            I = self.I.repeat(X.size(0), 1, 1)
        else:
            I = self.I.repeat(1, X.size(1), 1)
        # Padded elements of X are masked out while building the summary H.
        H, _ = self.mab0(I, X, X, key_padding_mask=pad_mask)
        # H holds num_inds induced points, none of which are padding, so no
        # key_padding_mask here.  (The original passed pad_mask, whose width
        # is the set size rather than num_inds — a shape mismatch at runtime.)
        out, _ = self.mab1(X, H, H)
        return out


"""
Transformer Encoder block
(time, batch, dim_in) -> (time, batch, dim_out) [Flip time and batch if batch_first = True]
"""

class TEB(nn.Module):
    """Transformer Encoder block.

    (time, batch, dim) -> (time, batch, dim); the time and batch axes are
    swapped when batch_first=True.  Elements flagged True in pad_mask are
    excluded from attention.
    """

    def __init__(self, dim, num_heads, num_layers=1, batch_first=False):
        super().__init__()
        # Feed-forward width is pinned to 2*dim; dropout is disabled so the
        # block is deterministic even in train mode.
        encoder_layer = nn.TransformerEncoderLayer(
            dim,
            nhead=num_heads,
            dim_feedforward=2 * dim,
            dropout=0.0,
            batch_first=batch_first,
        )
        self.enc = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, X, pad_mask=None):
        return self.enc(X, src_key_padding_mask=pad_mask)
71 changes: 71 additions & 0 deletions src/correction_network/deepsets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
########################################################################
# Author(s): Shubh Gupta
# Date: 21 September 2021
# Desc: Network models for GNSS-based position corrections
########################################################################
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn


########################################################
# DeepSets (src: https://github.com/yassersouri/pytorch-deep-sets)
class InvariantModel(nn.Module):
    """Permutation-invariant DeepSets wrapper: out = rho(sum_i phi(x_i)).

    phi embeds each set element independently; summing over the set axis
    (dim 0) makes the result order-independent; rho decodes the pooled
    embedding into the final output.
    """

    def __init__(self, phi: nn.Module, rho: nn.Module):
        super().__init__()
        self.phi = phi
        self.rho = rho

    def forward(self, x):
        # Embed every element, pool by summation over the set axis (dim 0),
        # then decode the pooled representation.
        embedded = self.phi(x)
        pooled = embedded.sum(dim=0)
        return self.rho(pooled)

class SmallPhi(nn.Module):
    """Two-layer MLP used as the per-element encoder (phi) in DeepSets.

    Maps (..., input_size) -> (..., output_size) through a single hidden
    ReLU layer of width hidden_size.
    """

    def __init__(self, input_size: int, output_size: int = 1, hidden_size: int = 10):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        self.fc1 = nn.Linear(self.input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, self.output_size)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class SmallRho(nn.Module):
    """Two-layer MLP used as the post-pooling decoder (rho) in DeepSets.

    Maps (..., input_size) -> (..., output_size) through a single hidden
    ReLU layer of width hidden_size.
    """

    def __init__(self, input_size: int, output_size: int = 1, hidden_size: int = 10):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        self.fc1 = nn.Linear(self.input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, self.output_size)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

class DeepSetModel(nn.Module):
    """DeepSets regressor: out = rho(sum_i phi(x_i)) over the set axis (dim 0).

    forward: (elements, batch, input_size) -> (batch, output_size).
    NOTE(review): phi is built as SmallPhi(input_size, hidden_size), so
    hidden_size is phi's OUTPUT width while phi's internal hidden layer
    stays at the SmallPhi default (10) — confirm this is intended.
    """

    def __init__(self, input_size: int, output_size: int = 1, hidden_size: int = 10):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # phi embeds each element; rho decodes the summed embedding.
        phi = SmallPhi(self.input_size, self.hidden_size)
        # Fix: the decoder is SmallRho, not a second SmallPhi.  The two
        # classes are currently structurally identical, so behavior is
        # unchanged, but the intent (and any future divergence of SmallRho)
        # is now respected.
        rho = SmallRho(self.hidden_size, self.output_size)
        self.net = InvariantModel(phi, rho)

    def forward(self, x, pad_mask=None):
        # pad_mask is accepted for interface parity with the attention-based
        # models but is ignored: padded elements still contribute to the sum.
        # NOTE(review): mask out padded rows before pooling if padding is used.
        out = self.net.forward(x)
        return out
144 changes: 0 additions & 144 deletions src/correction_network/networks.py

This file was deleted.

28 changes: 28 additions & 0 deletions src/correction_network/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn

"""
Pooling via Multi-headed Attention Block
(elements, batch, dim_in) -> (1, batch, dim_out) [Flip elements and batch if batch_first = True]
"""
class PMA(nn.Module):
    """Pooling via Multi-headed Attention.

    Compresses a set onto num_seeds learned query vectors:
    (elements, batch, dim) -> (num_seeds, batch, dim); the element and batch
    axes are swapped when batch_first=True.  Elements flagged True in
    pad_mask are excluded from the pooling attention.
    """

    def __init__(self, dim, num_heads, num_seeds, batch_first=False):
        super().__init__()
        # Learned seed queries, with a singleton axis where the batch goes.
        seed_shape = (1, num_seeds, dim) if batch_first else (num_seeds, 1, dim)
        self.S = nn.Parameter(torch.Tensor(*seed_shape))
        nn.init.xavier_uniform_(self.S)
        self.mab = nn.MultiheadAttention(dim, num_heads, batch_first=batch_first)
        self.batch_first = batch_first

    def forward(self, X, pad_mask=None):
        # Tile the seeds along whichever axis holds the batch.
        batch_size = X.size(0) if self.batch_first else X.size(1)
        if self.batch_first:
            queries = self.S.repeat(batch_size, 1, 1)
        else:
            queries = self.S.repeat(1, batch_size, 1)
        pooled, _ = self.mab(queries, X, X, key_padding_mask=pad_mask)
        return pooled

43 changes: 43 additions & 0 deletions src/correction_network/set_transformer_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
from .attention import *
from .pooling import *

"""
Modified Set Transformer to reduce a set of features to a fixed-dim output
(elements, batch, dim_in) -> (1, batch, dim_out) [Flip elements and batch if batch_first = True]
"""
class SetTransformerPointOutput(torch.nn.Module):
    """Modified Set Transformer reducing a set to a fixed-dimension output.

    (elements, batch, dim_input) -> (batch, dim_output); the element and
    batch axes of the input are swapped when batch_first=True.  Elements
    flagged True in pad_mask are ignored by the encoder and the pooling.
    """

    def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=64, num_heads=4, batch_first=False):
        super().__init__()

        # Per-element lift into the hidden width.
        self.feat_in = nn.Sequential(
            nn.Linear(dim_input, dim_hidden),
            nn.PReLU()
        )

        # Encode element interactions within the set.
        self.enc = TEB(dim_hidden, num_heads, num_layers=2, batch_first=batch_first)

        # Pool the variable-size set down to num_outputs seed vectors.
        self.pool = PMA(dim_hidden, num_heads, num_outputs, batch_first=batch_first)

        # Let the pooled seeds interact; no mask needed (seeds are never padding).
        self.dec = TEB(dim_hidden, num_heads, num_layers=2, batch_first=batch_first)

        # Flatten the num_outputs seeds and project to the output width.
        self.feat_out = nn.Sequential(
            nn.Linear(dim_hidden*num_outputs, dim_output)
        )
        self.batch_first = batch_first

    def forward(self, x, pad_mask=None):
        x = self.feat_in(x)
        x = self.enc(x, pad_mask=pad_mask)
        x = self.pool(x, pad_mask=pad_mask)
        x = self.dec(x)
        if self.batch_first:
            # (batch, num_outputs, dim_hidden) -> (batch, num_outputs*dim_hidden)
            x = x.reshape(x.shape[0], -1)
        else:
            # Bring the batch axis to the front before flattening.
            # Fix: Tensor.transpose takes exactly two dims — the original
            # 3-argument call is numpy's signature and raises TypeError;
            # permute is torch's multi-axis reordering.
            x = x.permute(1, 0, 2)
            x = x.reshape(x.shape[0], -1)
        out = self.feat_out(x)
        return out