Skip to content

Commit

Permalink
Merge pull request #36 from johndpope/fix/refactor-training
Browse files Browse the repository at this point in the history
Fix/refactor training
  • Loading branch information
johndpope authored Sep 3, 2024
2 parents 6d5b80e + b2b51e1 commit 1807bf6
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 134 deletions.
Binary file modified __pycache__/WebVid10M.cpython-311.pyc
Binary file not shown.
Binary file modified __pycache__/lia_resblocks.cpython-311.pyc
Binary file not shown.
Binary file modified __pycache__/loss.cpython-311.pyc
Binary file not shown.
Binary file modified __pycache__/resblocks.cpython-311.pyc
Binary file not shown.
103 changes: 103 additions & 0 deletions hog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import xgboost as xgb
from scipy.stats import skew, kurtosis
from skimage.feature import hog


def train_xgboost_anomaly_detector(normal_features, anomalous_features):
    """Fit a binary XGBoost classifier that separates normal from anomalous rows.

    Rows of ``normal_features`` receive label 0 and rows of
    ``anomalous_features`` receive label 1; both sets are stacked
    vertically into a single training matrix.

    Returns:
        The booster returned by ``xgb.train`` after 100 boosting rounds.
    """
    n_normal = len(normal_features)
    n_anomalous = len(anomalous_features)

    samples = np.vstack([normal_features, anomalous_features])
    labels = np.hstack([np.zeros(n_normal), np.ones(n_anomalous)])

    booster_params = dict(
        max_depth=3,
        eta=0.1,
        objective='binary:logistic',
        eval_metric='auc',
    )
    training_matrix = xgb.DMatrix(samples, label=labels)

    return xgb.train(booster_params, training_matrix, 100)

def compute_optical_flow_stats(prev_frame, curr_frame):
    """Summarise dense Farneback optical flow between two frames.

    Both frames are moved to the CPU, converted to NumPy, transposed with
    ``(1, 2, 0)`` and converted to grayscale via ``cv2.COLOR_RGB2GRAY``
    before the flow is computed.
    NOTE(review): presumably the inputs are channel-first RGB tensors —
    confirm against the caller.

    Returns:
        A 4-element array: mean and standard deviation of the flow
        magnitude, followed by mean and standard deviation of the flow angle.
    """
    frames_hwc = [
        tensor.cpu().numpy().transpose(1, 2, 0)
        for tensor in (prev_frame, curr_frame)
    ]
    first_gray, second_gray = (
        cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) for image in frames_hwc
    )

    # Positional Farneback parameters are kept exactly as before.
    flow_field = cv2.calcOpticalFlowFarneback(
        first_gray, second_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0
    )
    flow_mag, flow_ang = cv2.cartToPolar(flow_field[..., 0], flow_field[..., 1])

    return np.array([
        np.mean(flow_mag),
        np.std(flow_mag),
        np.mean(flow_ang),
        np.std(flow_ang),
    ])


def compute_hog_features(frame, orientations=9, pixels_per_cell=(8, 8), cells_per_block=(2, 2)):
    """Return the HOG descriptor of a frame.

    The frame is moved to the CPU, converted to NumPy and transposed to
    HWC layout, then handed to ``skimage.feature.hog`` with the last axis
    declared as the channel axis.
    """
    hwc_image = frame.cpu().numpy().transpose(1, 2, 0)
    hog_options = dict(
        orientations=orientations,
        pixels_per_cell=pixels_per_cell,
        cells_per_block=cells_per_block,
        channel_axis=-1,
    )
    return hog(hwc_image, **hog_options)

def compute_pixel_statistics(frame):
    """Compute per-channel mean, variance, skewness and kurtosis of a frame.

    Args:
        frame: Tensor whose axes 1 and 2 are reduced; it is moved to the
            CPU and converted to NumPy first.
            NOTE(review): presumably a (C, H, W) image — confirm with caller.

    Returns:
        The four statistics concatenated in the order mean, variance,
        skewness, kurtosis (``4 * C`` values for a 3-D input).
    """
    frame_np = frame.cpu().numpy()
    mean = np.mean(frame_np, axis=(1, 2))
    variance = np.var(frame_np, axis=(1, 2))

    # scipy.stats.skew / kurtosis document ``axis`` as an int (or None);
    # a tuple raises on older SciPy releases. Merge axes 1 and 2 into one
    # trailing axis so a plain integer axis gives the same reduction.
    moved = np.moveaxis(frame_np, (1, 2), (-2, -1))
    merged = moved.reshape(moved.shape[:-2] + (-1,))
    skewness = skew(merged, axis=-1)
    kurt = kurtosis(merged, axis=-1)

    return np.concatenate([mean, variance, skewness, kurt])

def compute_color_histogram(frame, bins=32):
    """Compute a fixed-range histogram for every channel of a frame.

    Args:
        frame: Tensor indexed by channel along its first dimension.
        bins: Number of histogram bins per channel.

    Returns:
        A 1-D tensor of length ``frame.shape[0] * bins`` holding the
        per-channel histograms back to back. Bins span ``[0, 1]``; per
        ``torch.histc`` semantics, values outside that range are ignored.
    """
    # Requires ``torch`` at module scope (it was missing from this file's imports).
    return torch.cat([
        torch.histc(frame[channel], bins=bins, min=0, max=1)
        for channel in range(frame.shape[0])
    ])

def get_latent_representation(frame, imf_model):
    """Encode a single frame with the model's latent token encoder.

    Args:
        frame: Un-batched frame tensor; a leading batch dimension of size 1
            is added before encoding.
        imf_model: Object exposing a callable ``latent_token_encoder``.

    Returns:
        The encoder output with its leading dimension squeezed away
        (``squeeze(0)`` only removes it when its size is 1).
    """
    # Requires ``torch`` at module scope (it was missing from this file's imports).
    # Gradient tracking is disabled: this is pure feature extraction.
    with torch.no_grad():
        latent = imf_model.latent_token_encoder(frame.unsqueeze(0))
    return latent.squeeze(0)


def compute_reconstruction_error(frame, reference_frame, imf_model):
    """Return the MSE between a frame and the model's reconstruction of it.

    Args:
        frame: Un-batched current frame tensor.
        reference_frame: Un-batched reference frame tensor.
        imf_model: Callable taking ``(current_batch, reference_batch)``.
            It may return either the reconstructed batch alone or a
            tuple/list whose first element is the reconstructed batch.

    Returns:
        The mean squared error as a Python float.
    """
    # Requires ``torch`` and ``F`` at module scope (both were missing from
    # this file's imports).
    with torch.no_grad():
        output = imf_model(frame.unsqueeze(0), reference_frame.unsqueeze(0))
        # Only unpack when the model really returned a sequence. Indexing a
        # bare tensor with [0] would drop the batch dimension and make
        # mse_loss compare mismatched shapes via broadcasting.
        reconstructed = output[0] if isinstance(output, (tuple, list)) else output
        mse = F.mse_loss(reconstructed, frame.unsqueeze(0))
    return mse.item()

def extract_frame_features(curr_frame, prev_frame, reference_frame, imf_model):
    """Build one flat feature vector describing ``curr_frame``.

    The vector concatenates, in this order: pixel statistics, HOG features,
    colour histogram, optical-flow statistics (``prev_frame`` to
    ``curr_frame``), the model's latent representation, and the scalar
    reconstruction error against ``reference_frame``.

    Returns:
        A 1-D NumPy array.
    """
    pixel_stats = compute_pixel_statistics(curr_frame)
    hog_features = compute_hog_features(curr_frame)
    color_hist = compute_color_histogram(curr_frame)
    flow_stats = compute_optical_flow_stats(prev_frame, curr_frame)
    latent_rep = get_latent_representation(curr_frame, imf_model)
    recon_error = compute_reconstruction_error(curr_frame, reference_frame, imf_model)

    components = [
        pixel_stats,
        hog_features,
        color_hist.cpu().numpy(),
        flow_stats,
        latent_rep.cpu().numpy(),
        [recon_error],
    ]
    # np.concatenate needs every piece to have the same rank; flatten each
    # one so a latent (or any other component) that is not already 1-D does
    # not raise. Already-flat components are unchanged.
    return np.concatenate([np.ravel(part) for part in components])
65 changes: 33 additions & 32 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vit import ImplicitMotionAlignment
from resblocks import FeatResBlock,UpConvResBlock,DownConvResBlock
from lia_resblocks import StyledConv,EqualConv2d,EqualLinear,ResBlock # these are correct https://github.com/hologerry/IMF/issues/4 "You can refer to this repo https://github.com/wyhsirius/LIA/ for StyleGAN2 related code, such as Encoder, Decoder."

from helper import normalize

from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import math
Expand Down Expand Up @@ -390,49 +390,50 @@ def style_mixing(self, t_c, t_r):
def forward(self, x_current, x_reference):
    """Reconstruct a frame from a current/reference frame pair.

    As visible here: dense features are encoded from ``x_reference``;
    latent tokens are encoded from both inputs, passed through
    ``mapping_network`` and ``add_noise``, decoded by
    ``latent_token_decoder``, aligned per level by the
    ``implicit_motion_alignment`` layers, and decoded by ``frame_decoder``.

    NOTE(review): this span reads like two revisions of the method
    interleaved — ``m_c``/``m_r`` are assigned twice and executable
    statements follow the first ``return``. Confirm which lines are live
    before relying on this description or on the return type.
    """
    # requires_grad_() flips the flag in place on the tensors passed in.
    x_current = x_current.requires_grad_()
    x_reference = x_reference.requires_grad_()

    # Dense feature encoding
    # Forward Pass
    f_r = self.dense_feature_encoder(x_reference)

    # Latent token encoding
    t_r = self.latent_token_encoder(x_reference)
    t_c = self.latent_token_encoder(x_current)

    # StyleGAN2-like mapping network
    t_r = self.mapping_network(t_r)
    t_c = self.mapping_network(t_c)

    # Add noise to latent tokens
    t_r = self.add_noise(t_r)
    t_c = self.add_noise(t_c)
    # noise_r = torch.randn_like(t_r) * self.noise_magnitude "no noise"
    # noise_c = torch.randn_like(t_c) * self.noise_magnitude
    # t_r = t_r + noise_r
    # t_c = t_c + noise_c

    # if torch.rand(()).item() < self.style_mixing_prob:
    # batch_size = t_c.size(0)
    # rand_indices = torch.randperm(batch_size)
    # rand_t_c = t_c[rand_indices]
    # rand_t_r = t_r[rand_indices]
    # mix_mask = torch.rand(batch_size, 1, device=t_c.device) < 0.5
    # mix_mask = mix_mask.float()
    # mix_t_c = t_c * mix_mask + rand_t_c * (1 - mix_mask)
    # mix_t_r = t_r * mix_mask + rand_t_r * (1 - mix_mask)
    # else:
    mix_t_c = t_c
    mix_t_r = t_r

    m_c = self.latent_token_decoder(mix_t_c)
    m_r = self.latent_token_decoder(mix_t_r)

    # Apply style mixing
    t_c, t_r = self.style_mixing(t_c, t_r)

    # Latent token decoding
    # NOTE(review): overwrites the m_c / m_r computed just above.
    m_r = self.latent_token_decoder(t_r)
    m_c = self.latent_token_decoder(t_c)

    # Implicit motion alignment with noise injection
    aligned_features = []
    for i in range(len(self.implicit_motion_alignment)):
        f_r_i = f_r[i]
        m_r_i = self.noise_injection(m_r[i])
        m_c_i = self.noise_injection(m_c[i])
        align_layer = self.implicit_motion_alignment[i]
        # NOTE(review): these two assignments replace the noise-injected
        # values set a few lines earlier.
        m_c_i = m_c[i]
        m_r_i = m_r[i]
        aligned_feature = align_layer(m_c_i, m_r_i, f_r_i)
        aligned_features.append(aligned_feature)


    # Frame decoding
    reconstructed_frame = self.frame_decoder(aligned_features)

    return reconstructed_frame, {
        'dense_features': f_r,
        'latent_tokens': (t_c, t_r),
        'motion_features': (m_c, m_r),
        'aligned_features': aligned_features
    }
    # NOTE(review): as written, everything below follows a return statement.
    x_reconstructed = self.frame_decoder(aligned_features)
    x_reconstructed = normalize(x_reconstructed)
    return x_reconstructed
    # return reconstructed_frame, {
    # 'dense_features': f_r,
    # 'latent_tokens': (t_c, t_r),
    # 'motion_features': (m_c, m_r),
    # 'aligned_features': aligned_features
    # }

def set_noise_level(self, noise_level):
    """Store ``noise_level`` on the instance as ``self.noise_level``."""
    self.noise_level = noise_level
Expand Down
Loading

0 comments on commit 1807bf6

Please sign in to comment.