diff --git a/hog.py b/hog.py
new file mode 100644
index 0000000..d8b9c4d
--- /dev/null
+++ b/hog.py
@@ -0,0 +1,105 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.stats import skew, kurtosis
+from skimage.feature import hog
+import cv2
+import xgboost as xgb
+
+
+def train_xgboost_anomaly_detector(normal_features, anomalous_features):
+    """Train an XGBoost classifier for anomaly detection."""
+    X = np.vstack([normal_features, anomalous_features])
+    y = np.hstack([np.zeros(len(normal_features)), np.ones(len(anomalous_features))])
+
+    dtrain = xgb.DMatrix(X, label=y)
+
+    params = {
+        'max_depth': 3,
+        'eta': 0.1,
+        'objective': 'binary:logistic',
+        'eval_metric': 'auc'
+    }
+
+    num_round = 100
+    bst = xgb.train(params, dtrain, num_round)
+
+    return bst
+
+def compute_optical_flow_stats(prev_frame, curr_frame):
+    """Compute optical flow statistics between two frames."""
+    prev_np = prev_frame.cpu().numpy().transpose(1, 2, 0)
+    curr_np = curr_frame.cpu().numpy().transpose(1, 2, 0)
+
+    prev_gray = cv2.cvtColor(prev_np, cv2.COLOR_RGB2GRAY)
+    curr_gray = cv2.cvtColor(curr_np, cv2.COLOR_RGB2GRAY)
+
+    flow = cv2.calcOpticalFlowFarneback(prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
+
+    magnitude, angle = cv2.cartToPolar(flow[..., 0], flow[..., 1])
+
+    stats = {
+        'mean_magnitude': np.mean(magnitude),
+        'std_magnitude': np.std(magnitude),
+        'mean_angle': np.mean(angle),
+        'std_angle': np.std(angle)
+    }
+    return np.array(list(stats.values()))
+
+
+def compute_hog_features(frame, orientations=9, pixels_per_cell=(8, 8), cells_per_block=(2, 2)):
+    """Compute HOG features for a frame."""
+    frame_np = frame.cpu().numpy().transpose(1, 2, 0)  # Change to HWC format
+    features = hog(frame_np, orientations=orientations, pixels_per_cell=pixels_per_cell,
+                   cells_per_block=cells_per_block, channel_axis=-1)
+    return features
+
+def compute_pixel_statistics(frame):
+    """Compute per-channel statistical measures (mean, variance, skewness, kurtosis)."""
+    frame_np = frame.cpu().numpy()
+    mean = np.mean(frame_np, axis=(1, 2))
+    variance = np.var(frame_np, axis=(1, 2))
+    skewness = skew(frame_np, axis=(1, 2))
+    kurt = kurtosis(frame_np, axis=(1, 2))
+    return np.concatenate([mean, variance, skewness, kurt])
+
+def compute_color_histogram(frame, bins=32):
+    """Compute a color histogram for each channel."""
+    histograms = []
+    for channel in range(frame.shape[0]):
+        hist = torch.histc(frame[channel], bins=bins, min=0, max=1)
+        histograms.append(hist)
+    return torch.cat(histograms)
+
+def get_latent_representation(frame, imf_model):
+    """Extract the latent representation from the IMF model's encoder."""
+    with torch.no_grad():
+        latent = imf_model.latent_token_encoder(frame.unsqueeze(0))
+    return 
latent.squeeze(0)
+
+
+def compute_reconstruction_error(frame, reference_frame, imf_model):
+    """Compute reconstruction error using the IMF model."""
+    with torch.no_grad():
+        # forward now returns the reconstructed frame tensor directly
+        reconstructed = imf_model(frame.unsqueeze(0), reference_frame.unsqueeze(0))
+
+    mse = F.mse_loss(reconstructed, frame.unsqueeze(0))
+    return mse.item()
+
+def extract_frame_features(curr_frame, prev_frame, reference_frame, imf_model):
+    """Extract all features for a given frame."""
+    pixel_stats = compute_pixel_statistics(curr_frame)
+    hog_features = compute_hog_features(curr_frame)
+    color_hist = compute_color_histogram(curr_frame)
+    flow_stats = compute_optical_flow_stats(prev_frame, curr_frame)
+    latent_rep = get_latent_representation(curr_frame, imf_model)
+    recon_error = compute_reconstruction_error(curr_frame, reference_frame, imf_model)
+
+    return np.concatenate([
+        pixel_stats,
+        hog_features,
+        color_hist.cpu().numpy(),
+        flow_stats,
+        latent_rep.cpu().numpy(),
+        [recon_error]
+    ])
diff --git a/model.py b/model.py
index d102222..dbe462f 100644
--- a/model.py
+++ b/model.py
@@ -10,7 +10,7 @@ from vit import ImplicitMotionAlignment
 from resblocks import FeatResBlock,UpConvResBlock,DownConvResBlock
 from lia_resblocks import StyledConv,EqualConv2d,EqualLinear,ResBlock
 # these are correct https://github.com/hologerry/IMF/issues/4 "You can refer to this repo https://github.com/wyhsirius/LIA/ for StyleGAN2 related code, such as Encoder, Decoder."
-
+from helper import normalize
 from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
 import math
@@ -390,49 +390,50 @@ def style_mixing(self, t_c, t_r):
     def forward(self, x_current, x_reference):
         x_current = x_current.requires_grad_()
         x_reference = x_reference.requires_grad_()
-
-        # Dense feature encoding
+        # Forward Pass
         f_r = self.dense_feature_encoder(x_reference)
-
-        # Latent token encoding
         t_r = self.latent_token_encoder(x_reference)
         t_c = self.latent_token_encoder(x_current)
 
-        # StyleGAN2-like mapping network
-        t_r = self.mapping_network(t_r)
-        t_c = self.mapping_network(t_c)
-
-        # Add noise to latent tokens
-        t_r = self.add_noise(t_r)
-        t_c = self.add_noise(t_c)
+        # noise_r = torch.randn_like(t_r) * self.noise_magnitude  # noise disabled
+        # noise_c = torch.randn_like(t_c) * self.noise_magnitude
+        # t_r = t_r + noise_r
+        # t_c = t_c + noise_c
+
+        # if torch.rand(()).item() < self.style_mixing_prob:
+        #     batch_size = t_c.size(0)
+        #     rand_indices = torch.randperm(batch_size)
+        #     rand_t_c = t_c[rand_indices]
+        #     rand_t_r = t_r[rand_indices]
+        #     mix_mask = torch.rand(batch_size, 1, device=t_c.device) < 0.5
+        #     mix_mask = mix_mask.float()
+        #     mix_t_c = t_c * mix_mask + rand_t_c * (1 - mix_mask)
+        #     mix_t_r = t_r * mix_mask + rand_t_r * (1 - mix_mask)
+        # else:
+        mix_t_c = t_c
+        mix_t_r = t_r
+
+        m_c = self.latent_token_decoder(mix_t_c)
+        m_r = self.latent_token_decoder(mix_t_r)
 
-        # Apply style mixing
-        t_c, t_r = self.style_mixing(t_c, t_r)
-
-        # Latent token decoding
-        m_r = self.latent_token_decoder(t_r)
-        m_c = self.latent_token_decoder(t_c)
-
-        # Implicit motion alignment with noise injection
         aligned_features = []
         for i in range(len(self.implicit_motion_alignment)):
             f_r_i = f_r[i]
-            m_r_i = self.noise_injection(m_r[i])
-            m_c_i = self.noise_injection(m_c[i])
             align_layer = self.implicit_motion_alignment[i]
+            m_c_i = m_c[i]
+            m_r_i = m_r[i]
             aligned_feature = align_layer(m_c_i, m_r_i, f_r_i)
             aligned_features.append(aligned_feature)
-
-        # Frame decoding
-        reconstructed_frame = self.frame_decoder(aligned_features)
-
-        return reconstructed_frame, {
-            
'dense_features': f_r, - 'latent_tokens': (t_c, t_r), - 'motion_features': (m_c, m_r), - 'aligned_features': aligned_features - } + x_reconstructed = self.frame_decoder(aligned_features) + x_reconstructed = normalize(x_reconstructed) + return x_reconstructed + # return reconstructed_frame, { + # 'dense_features': f_r, + # 'latent_tokens': (t_c, t_r), + # 'motion_features': (m_c, m_r), + # 'aligned_features': aligned_features + # } def set_noise_level(self, noise_level): self.noise_level = noise_level diff --git a/train.py b/train.py index 77c6aae..800a86d 100644 --- a/train.py +++ b/train.py @@ -21,7 +21,8 @@ import torchvision.models as models from loss import gan_loss_fn,MediaPipeEyeEnhancementLoss # from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.optim.lr_scheduler import CosineAnnealingLR +# from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import OneCycleLR import random from stylegan import EMA @@ -91,10 +92,9 @@ def __init__(self, config, model, discriminator, train_dataloader, accelerator): self.optimizer_d = AdamW(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)) # Learning rate schedulers - # self.scheduler_g = CosineAnnealingLR(self.optimizer_g, T_max=100, eta_min=1e-6) - # self.scheduler_d = CosineAnnealingLR(self.optimizer_d, T_max=100, eta_min=1e-6) - self.scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer_g, mode='min', factor=0.5, patience=5, verbose=True) - self.scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer_d, mode='min', factor=0.5, patience=5, verbose=True) + total_steps = config.training.num_epochs * len(train_dataloader) + self.scheduler_g = OneCycleLR(self.optimizer_g, max_lr=2e-4, total_steps=total_steps) + self.scheduler_d = OneCycleLR(self.optimizer_d, max_lr=2e-4, total_steps=total_steps) if config.training.use_ema: @@ -117,70 +117,18 @@ def check_exploding_gradients(self, model): return True return False - def adjust_learning_rate(self, optimizer, factor=0.1, min_lr=1e-6): - for param_group in optimizer.param_groups: - param_group['lr'] = max(param_group['lr'] * factor, min_lr) - print(f"🔥 Adjusted learning rate. 
New LR: {optimizer.param_groups[0]['lr']}") - - - def train_step(self, x_current, x_reference,global_step): + def train_step(self, x_current, x_reference, global_step): if x_current.nelement() == 0: print("🔥 Skipping training step due to empty x_current") return None, None, None, None, None, None - - - # Forward Pass - f_r = self.model.dense_feature_encoder(x_reference) - t_r = self.model.latent_token_encoder(x_reference) - t_c = self.model.latent_token_encoder(x_current) - - noise_r = torch.randn_like(t_r) * self.noise_magnitude - noise_c = torch.randn_like(t_c) * self.noise_magnitude - t_r = t_r + noise_r - t_c = t_c + noise_c - - if torch.rand(()).item() < self.style_mixing_prob: - batch_size = t_c.size(0) - rand_indices = torch.randperm(batch_size) - rand_t_c = t_c[rand_indices] - rand_t_r = t_r[rand_indices] - mix_mask = torch.rand(batch_size, 1, device=t_c.device) < 0.5 - mix_mask = mix_mask.float() - mix_t_c = t_c * mix_mask + rand_t_c * (1 - mix_mask) - mix_t_r = t_r * mix_mask + rand_t_r * (1 - mix_mask) - else: - mix_t_c = t_c - mix_t_r = t_r - m_c = self.model.latent_token_decoder(mix_t_c) - m_r = self.model.latent_token_decoder(mix_t_r) + # Generate reconstructed frame + x_reconstructed = self.model(x_current, x_reference) - aligned_features = [] - for i in range(len(self.model.implicit_motion_alignment)): - f_r_i = f_r[i] - align_layer = self.model.implicit_motion_alignment[i] - m_c_i = m_c[i] - m_r_i = m_r[i] - aligned_feature = align_layer(m_c_i, m_r_i, f_r_i) - aligned_features.append(aligned_feature) - - x_reconstructed = self.model.frame_decoder(aligned_features) - x_reconstructed = normalize(x_reconstructed) - - - # eye loss - # l_eye = self.eye_loss_fn(x_reconstructed, x_current) - if self.config.training.use_subsampling: - sub_sample_size = (128, 128) # As mentioned in the paper + sub_sample_size = (128, 128) # As mentioned in the paper https://github.com/johndpope/MegaPortrait-hack/issues/41 x_current, x_reconstructed = consistent_sub_sample(x_current, x_reconstructed, sub_sample_size) - - if global_step % self.config.logging.sample_every == 0: - save_image(x_reconstructed, 'x_reconstructed.png', normalize=True) - save_image(x_current, 'x_current.png', normalize=True) - save_image(x_reference, 'x_reference.png', normalize=True) - # Discriminator updates d_loss_total = 0 for _ in range(self.config.training.n_critic): @@ -220,7 +168,6 @@ def train_step(self, x_current, x_reference,global_step): d_loss_total += d_loss.item() - # Average discriminator loss d_loss_avg = d_loss_total / self.config.training.n_critic @@ -241,19 +188,27 @@ def train_step(self, x_current, x_reference,global_step): self.accelerator.backward(g_loss) if self.check_exploding_gradients(self.model): - print("🔥 Exploding gradients detected. Adjusting learning rate.") - self.adjust_learning_rate(self.optimizer_g) - self.adjust_learning_rate(self.optimizer_d) - self.optimizer_g.zero_grad() - self.optimizer_d.zero_grad() + print("🔥 Exploding gradients detected. 
Clipping gradients.") + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) else: if self.config.training.clip_grad: torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.config.training.clip_grad_norm) - self.optimizer_g.step() + + self.optimizer_g.step() + + # Step the schedulers + self.scheduler_g.step() + self.scheduler_d.step() if self.ema: self.ema.update() + # Logging - locally for sanity check + if global_step % self.config.logging.sample_every == 0: + save_image(x_reconstructed, f'x_reconstructed.png', normalize=True) + save_image(x_current, f'x_current.png', normalize=True) + save_image(x_reference, f'x_reference.png', normalize=True) + return d_loss_avg, g_loss.item(), l_p.item(), l_v.item(), g_loss_gan.item(), x_reconstructed def train(self, start_epoch=0): @@ -339,9 +294,6 @@ def train(self, start_epoch=0): avg_g_loss = epoch_g_loss / num_valid_steps avg_d_loss = epoch_d_loss / num_valid_steps - # Step the schedulers - self.scheduler_g.step(avg_g_loss) - self.scheduler_d.step(avg_d_loss) @@ -356,42 +308,51 @@ def train(self, start_epoch=0): self.save_checkpoint(epoch, is_final=True) - def load_checkpoint(self, checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location=self.accelerator.device) - - # Load model state - self.model.load_state_dict(checkpoint['model_state_dict']) - - # Load discriminator state - self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) - - # Load optimizer states - self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict']) - self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict']) - - # Load epoch - start_epoch = checkpoint['epoch'] + 1 - - print(f"Loaded checkpoint from epoch {start_epoch - 1}") - return start_epoch - def save_checkpoint(self, epoch, is_final=False): self.accelerator.wait_for_everyone() unwrapped_model = self.accelerator.unwrap_model(self.model) unwrapped_discriminator = self.accelerator.unwrap_model(self.discriminator) - if is_final: - save_path = f"{self.config.checkpoints.dir}/final_model.pth" - self.accelerator.save(unwrapped_model.state_dict(), save_path) - else: - save_path = f"{self.config.checkpoints.dir}/checkpoint.pth" - self.accelerator.save({ - 'epoch': epoch, - 'model_state_dict': unwrapped_model.state_dict(), - 'discriminator_state_dict': unwrapped_discriminator.state_dict(), - 'optimizer_g_state_dict': self.optimizer_g.state_dict(), - 'optimizer_d_state_dict': self.optimizer_d.state_dict(), - }, save_path) + checkpoint = { + 'epoch': epoch, + 'model_state_dict': unwrapped_model.state_dict(), + 'discriminator_state_dict': unwrapped_discriminator.state_dict(), + 'optimizer_g_state_dict': self.optimizer_g.state_dict(), + 'optimizer_d_state_dict': self.optimizer_d.state_dict(), + 'scheduler_g_state_dict': self.scheduler_g.state_dict(), + 'scheduler_d_state_dict': self.scheduler_d.state_dict(), + } + + if self.ema: + checkpoint['ema_state_dict'] = self.ema.state_dict() + + save_path = f"{self.config.checkpoints.dir}/{'final_model' if is_final else 'checkpoint'}.pth" + self.accelerator.save(checkpoint, save_path) + print(f"Saved checkpoint for epoch {epoch}") + + def load_checkpoint(self, checkpoint_path): + try: + checkpoint = self.accelerator.load(checkpoint_path) + + self.model.load_state_dict(checkpoint['model_state_dict']) + self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) + self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict']) + 
self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict']) + self.scheduler_g.load_state_dict(checkpoint['scheduler_g_state_dict']) + self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict']) + + if self.ema and 'ema_state_dict' in checkpoint: + self.ema.load_state_dict(checkpoint['ema_state_dict']) + + start_epoch = checkpoint['epoch'] + 1 + print(f"Loaded checkpoint from epoch {start_epoch - 1}") + return start_epoch + except FileNotFoundError: + print(f"No checkpoint found at {checkpoint_path}") + return 0 + except Exception as e: + print(f"Error loading checkpoint: {e}") + return 0 def main(): config = load_config('config.yaml')
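Note (not part of the patch): a minimal sketch of how the new hog.py pieces are meant to compose end to end. `featurize_clip` is a hypothetical helper, the random tensors stand in for real video frames, and `imf_model` is assumed to be a trained IMF instance from model.py exposing `latent_token_encoder` and the forward pass used above.

import numpy as np
import torch
import xgboost as xgb
from hog import extract_frame_features, train_xgboost_anomaly_detector

def featurize_clip(frames, imf_model):
    """frames: list of float32 (C, H, W) tensors in [0, 1]; frame 0 is the reference."""
    reference = frames[0]
    feats = [extract_frame_features(curr, prev, reference, imf_model)
             for prev, curr in zip(frames[:-1], frames[1:])]
    return np.stack(feats)

normal_feats = featurize_clip([torch.rand(3, 256, 256) for _ in range(8)], imf_model)
anomalous_feats = featurize_clip([torch.rand(3, 256, 256) for _ in range(8)], imf_model)
bst = train_xgboost_anomaly_detector(normal_feats, anomalous_feats)

# Booster.predict on a DMatrix returns P(anomalous) per frame under binary:logistic.
test_feats = featurize_clip([torch.rand(3, 256, 256) for _ in range(8)], imf_model)
scores = bst.predict(xgb.DMatrix(test_feats))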
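Note (not part of the patch): the scheduler swap also changes the stepping cadence. ReduceLROnPlateau was stepped once per epoch on a validation metric, while OneCycleLR must be stepped once per optimizer update, which is why scheduler.step() moved into train_step and total_steps is num_epochs * len(train_dataloader). A self-contained toy illustration of that accounting (dummy model and step counts, not the trainer's actual objects):

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
num_epochs, steps_per_epoch = 3, 10
sched = OneCycleLR(opt, max_lr=2e-4, total_steps=num_epochs * steps_per_epoch)

lrs = []
for _ in range(num_epochs * steps_per_epoch):
    opt.step()    # optimizer first, then scheduler, once per batch (not per epoch)
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])

# lr warms up to max_lr (by default ~30% of the way through the schedule), then
# anneals to roughly max_lr / (25 * 1e4); stepping past total_steps raises a ValueError.
print(f"peak lr: {max(lrs):.2e}, final lr: {lrs[-1]:.2e}")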