dataset improvements
MarcusLoppe committed Jun 17, 2024
1 parent bdfcade commit d166f2f
Showing 2 changed files with 46 additions and 17 deletions.
59 changes: 42 additions & 17 deletions meshgpt_pytorch/mesh_dataset.py
@@ -2,6 +2,7 @@
 import numpy as np
 from torch.nn.utils.rnn import pad_sequence
 from tqdm import tqdm
+import torch
 from meshgpt_pytorch import (
     MeshAutoencoder,
     MeshTransformer
@@ -75,25 +76,44 @@ def sort_dataset_keys(self):
             {key: d[key] for key in desired_order if key in d} for d in self.data
         ]
 
-    def generate_face_edges(self):
-        i = 0
-        for item in self.data:
-            if 'face_edges' not in item:
-                item['face_edges'] = derive_face_edges_from_faces(item['faces'])
-                i += 1
+    def generate_face_edges(self, batch_size = 5):
+        data_to_process = [item for item in self.data if 'faces_edges' not in item]
+
+        total_batches = (len(data_to_process) + batch_size - 1) // batch_size
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        for i in tqdm(range(0, len(data_to_process), batch_size), total=total_batches):
+            batch_data = data_to_process[i:i+batch_size]
+
+            if not batch_data:
+                continue
+
+            padded_batch_faces = pad_sequence(
+                [item['faces'] for item in batch_data],
+                batch_first=True,
+                padding_value=-1
+            ).to(device)
+
+            batched_faces_edges = derive_face_edges_from_faces(padded_batch_faces, pad_id=-1)
+
+            mask = (batched_faces_edges != -1).all(dim=-1)
+            for item_idx, (item_edges, item_mask) in enumerate(zip(batched_faces_edges, mask)):
+                item_edges_masked = item_edges[item_mask]
+                item = batch_data[item_idx]
+                item['face_edges'] = item_edges_masked
 
         self.sort_dataset_keys()
-        print(f"[MeshDataset] Generated face_edges for {i}/{len(self.data)} entries")
+        print(f"[MeshDataset] Generated face_edges for {len(data_to_process)} entries")

     def generate_codes(self, autoencoder : MeshAutoencoder, batch_size = 25):
         total_batches = (len(self.data) + batch_size - 1) // batch_size
 
         for i in tqdm(range(0, len(self.data), batch_size), total=total_batches):
             batch_data = self.data[i:i+batch_size]
 
-            padded_batch_vertices = pad_sequence([item['vertices'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id)
-            padded_batch_faces = pad_sequence([item['faces'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id)
-            padded_batch_face_edges = pad_sequence([item['face_edges'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id)
+            padded_batch_vertices = pad_sequence([item['vertices'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
+            padded_batch_faces = pad_sequence([item['faces'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
+            padded_batch_face_edges = pad_sequence([item['face_edges'] for item in batch_data], batch_first=True, padding_value=autoencoder.pad_id).to(autoencoder.device)
 
             batch_codes = autoencoder.tokenize(
                 vertices=padded_batch_vertices,
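The only change in `generate_codes` is moving each padded batch onto the autoencoder's device before tokenizing, since `pad_sequence` allocates on CPU. A hedged usage sketch, where `autoencoder` (a trained `MeshAutoencoder`) and `dataset` (a populated `MeshDataset`) are assumed placeholders:

```python
import torch

# pad_sequence builds CPU tensors; generate_codes now moves each batch to
# wherever the autoencoder's parameters live, as reported by the new `device`
# property added in meshgpt_pytorch.py below
autoencoder = autoencoder.to("cuda" if torch.cuda.is_available() else "cpu")

dataset.generate_codes(autoencoder, batch_size = 25)
# each dataset item now carries a 'codes' entry from autoencoder.tokenize(...)
```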
@@ -112,17 +132,22 @@ def generate_codes(self, autoencoder : MeshAutoencoder, batch_size = 25):
print(f"[MeshDataset] Generated codes for {len(self.data)} entries")

def embed_texts(self, transformer : MeshTransformer, batch_size = 50):
unique_texts = set(item['texts'] for item in self.data)
unique_texts = list(set(item['texts'] for item in self.data))
embeddings = []
for i in range(0,len(unique_texts), batch_size):
text_embeddings = transformer.embed_texts(list(unique_texts)[i:i+batch_size])
embeddings.extend(text_embeddings)
text_embedding_dict = {}
for i in tqdm(range(0,len(unique_texts), batch_size)):
batch_texts = unique_texts[i:i+batch_size]
text_embeddings = transformer.embed_texts(batch_texts)
mask = (text_embeddings != transformer.conditioner.text_embed_pad_value).all(dim=-1)

text_embedding_dict = dict(zip(unique_texts, embeddings))

for item in self.data:
for idx, text in enumerate(batch_texts):
masked_embedding = text_embeddings[idx][mask[idx]]
text_embedding_dict[text] = masked_embedding

for item in self.data:
if 'texts' in item:
item['text_embeds'] = text_embedding_dict.get(item['texts'], None)
del item['texts']

self.sort_dataset_keys()
print(f"[MeshDataset] Generated {len(embeddings)} text_embeddings")
4 changes: 4 additions & 0 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -635,6 +635,10 @@ def __init__(
         self.commit_loss_weight = commit_loss_weight
         self.bin_smooth_blur_sigma = bin_smooth_blur_sigma
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @classmethod
     def _from_pretrained(
         cls,
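The new `device` property is the common PyTorch idiom of reading a module's device off its first parameter, and it is what lets `generate_codes` above find the right target device. A self-contained illustration of the idiom (not meshgpt-specific; assumes the module has at least one parameter):

```python
import torch
from torch import nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def device(self):
        # parameters() yields tensors that all share the module's device
        return next(self.parameters()).device

model = Tiny()
print(model.device)                  # cpu
if torch.cuda.is_available():
    print(model.to("cuda").device)   # cuda:0
```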
