diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py index 451bec3..ab4e330 100644 --- a/meshgpt_pytorch/meshgpt_pytorch.py +++ b/meshgpt_pytorch/meshgpt_pytorch.py @@ -157,10 +157,18 @@ def derive_angle(x, y, eps = 1e-5): def get_derived_face_features( face_coords: Float['b nf nvf 3'] # 3 or 4 vertices with 3 coordinates ): - shifted_face_coords = torch.cat((face_coords[:, :, -1:], face_coords[:, :, :-1]), dim = 2) + is_quad = face_coords.shape[-2] == 4 + + # shift face coordinates depending on triangles or quads + + shifted_face_coords = torch.roll(face_coords, 1, dims = (2,)) angles = derive_angle(face_coords, shifted_face_coords) + if is_quad: + # @sbriseid says quads need to be shifted by 2 + shifted_face_coords = torch.roll(shifted_face_coords, 1, dims = (2,)) + edge1, edge2, *_ = (face_coords - shifted_face_coords).unbind(dim = 2) cross_product = torch.cross(edge1, edge2, dim = -1) diff --git a/meshgpt_pytorch/version.py b/meshgpt_pytorch/version.py index 7bf3831..b0ff349 100644 --- a/meshgpt_pytorch/version.py +++ b/meshgpt_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.5.7' +__version__ = '1.5.12' diff --git a/setup.py b/setup.py index 5da26ce..437d381 100644 --- a/setup.py +++ b/setup.py @@ -31,10 +31,11 @@ 'environs', 'gateloop-transformer>=0.2.2', 'jaxtyping', - 'local-attention>=1.9.0', + 'local-attention>=1.9.11', 'numpy', 'pytorch-custom-utils>=0.0.9', 'rotary-embedding-torch>=0.6.4', + 'sentencepiece', 'taylor-series-linear-attention>=0.1.6', 'torch>=2.1', 'torch_geometric',