From d166f2f25053694cbba16e7d628a031aa1b83dfc Mon Sep 17 00:00:00 2001
From: MarcusLoppe
Date: Mon, 17 Jun 2024 18:11:08 +0200
Subject: [PATCH] dataset improvements

---
 meshgpt_pytorch/mesh_dataset.py    | 59 +++++++++++++++++++++---------
 meshgpt_pytorch/meshgpt_pytorch.py |  4 ++
 2 files changed, 46 insertions(+), 17 deletions(-)

diff --git a/meshgpt_pytorch/mesh_dataset.py b/meshgpt_pytorch/mesh_dataset.py
index 1360171..0a40397 100644
--- a/meshgpt_pytorch/mesh_dataset.py
+++ b/meshgpt_pytorch/mesh_dataset.py
@@ -2,6 +2,7 @@ import numpy as np
 from torch.nn.utils.rnn import pad_sequence
 from tqdm import tqdm
+import torch
 from meshgpt_pytorch import (
     MeshAutoencoder,
     MeshTransformer
 )
@@ -75,15 +76,34 @@ def sort_dataset_keys(self):
             {key: d[key] for key in desired_order if key in d}
             for d in self.data
         ]
 
-    def generate_face_edges(self):
-        i = 0
-        for item in self.data:
-            if 'face_edges' not in item:
-                item['face_edges'] = derive_face_edges_from_faces(item['faces'])
-                i += 1
+    def generate_face_edges(self, batch_size = 5):
+        data_to_process = [item for item in self.data if 'face_edges' not in item]
+
+        total_batches = (len(data_to_process) + batch_size - 1) // batch_size
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        for i in tqdm(range(0, len(data_to_process), batch_size), total=total_batches):
+            batch_data = data_to_process[i:i+batch_size]
+
+            if not batch_data:
+                continue
+            padded_batch_faces = pad_sequence(
+                [item['faces'] for item in batch_data],
+                batch_first=True,
+                padding_value=-1
+            ).to(device)
+
+            batched_faces_edges = derive_face_edges_from_faces(padded_batch_faces, pad_id=-1)
+
+            mask = (batched_faces_edges != -1).all(dim=-1)
+            for item_idx, (item_edges, item_mask) in enumerate(zip(batched_faces_edges, mask)):
+                item_edges_masked = item_edges[item_mask]
+                item = batch_data[item_idx]
+                item['face_edges'] = item_edges_masked
+
         self.sort_dataset_keys()
-        print(f"[MeshDataset] Generated face_edges for {i}/{len(self.data)} entries")
+        print(f"[MeshDataset] Generated face_edges for {len(data_to_process)} entries")
 
     def generate_codes(self, autoencoder : MeshAutoencoder, batch_size = 25):
         total_batches = (len(self.data) + batch_size - 1) // batch_size
@@ -91,9 +111,9 @@
         for i in tqdm(range(0, len(self.data), batch_size), total=total_batches):
             batch_data = self.data[i:i+batch_size]
 
-            padded_batch_vertices = pad_sequence([item['vertices'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id)
-            padded_batch_faces = pad_sequence([item['faces'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id)
-            padded_batch_face_edges = pad_sequence([item['face_edges'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id)
+            padded_batch_vertices = pad_sequence([item['vertices'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
+            padded_batch_faces = pad_sequence([item['faces'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
+            padded_batch_face_edges = pad_sequence([item['face_edges'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
 
             batch_codes = autoencoder.tokenize(
                 vertices=padded_batch_vertices,
@@ -112,17 +132,22 @@
         print(f"[MeshDataset] Generated codes for {len(self.data)} entries")
 
     def embed_texts(self, transformer : MeshTransformer, batch_size = 50):
-        unique_texts = set(item['texts'] for item in self.data)
+        unique_texts = list(set(item['texts'] for item in self.data))
         embeddings = []
-        for i in range(0,len(unique_texts), batch_size):
-            text_embeddings = transformer.embed_texts(list(unique_texts)[i:i+batch_size])
-            embeddings.extend(text_embeddings)
+        text_embedding_dict = {}
+        for i in tqdm(range(0,len(unique_texts), batch_size)):
+            batch_texts = unique_texts[i:i+batch_size]
+            text_embeddings = transformer.embed_texts(batch_texts)
+            mask = (text_embeddings != transformer.conditioner.text_embed_pad_value).all(dim=-1)
 
-        text_embedding_dict = dict(zip(unique_texts, embeddings))
-
-        for item in self.data:
+            for idx, text in enumerate(batch_texts):
+                masked_embedding = text_embeddings[idx][mask[idx]]
+                text_embedding_dict[text] = masked_embedding
+
+        for item in self.data:
             if 'texts' in item:
                 item['text_embeds'] = text_embedding_dict.get(item['texts'], None)
                 del item['texts']
 
+        self.sort_dataset_keys()
         print(f"[MeshDataset] Generated {len(embeddings)} text_embeddings")
diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py
index efb653d..3094394 100644
--- a/meshgpt_pytorch/meshgpt_pytorch.py
+++ b/meshgpt_pytorch/meshgpt_pytorch.py
@@ -635,6 +635,10 @@ def __init__(
         self.commit_loss_weight = commit_loss_weight
         self.bin_smooth_blur_sigma = bin_smooth_blur_sigma
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @classmethod
     def _from_pretrained(
         cls,
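
A minimal usage sketch of the batched helpers this patch touches follows. It assumes the dataset class in mesh_dataset.py is exposed as MeshDataset and accepts a plain list of dicts with 'vertices', 'faces' and 'texts' entries, and that the transformer is built with text conditioning enabled; the toy geometry and constructor arguments are illustrative only and not part of the patch.

import torch
from meshgpt_pytorch import MeshAutoencoder, MeshTransformer
from meshgpt_pytorch.mesh_dataset import MeshDataset  # assumed export of the dataset class

# toy models; in practice these would be trained or loaded from a checkpoint
autoencoder = MeshAutoencoder(num_discrete_coors = 128)
transformer = MeshTransformer(autoencoder, dim = 512, condition_on_text = True)

# assumed entry layout: float vertex coordinates, long face indices, one caption per mesh
data = [
    {
        'vertices': torch.randn(8, 3),
        'faces': torch.randint(0, 8, (12, 3)),
        'texts': 'a cube',
    }
]
dataset = MeshDataset(data)

dataset.generate_face_edges(batch_size = 5)           # pads faces per batch, derives edges on GPU when available
dataset.embed_texts(transformer, batch_size = 50)     # caches one embedding per unique caption, then drops 'texts'
dataset.generate_codes(autoencoder, batch_size = 25)  # tokenizes on autoencoder.device via the new property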