About the visualization #3

Open
ztbian-bzt opened this issue May 31, 2024 · 9 comments

Comments

@ztbian-bzt

Could you consider providing the code to visualize the segmentation results? Just like the figure shows in the paper and Readme markdown. Thanks.

@LucasKre
Owner

LucasKre commented Jun 3, 2024

@ztbian-bzt I added an inference script that can be used to predict the labels for a given instance in the dataset and save the output as color-coded meshes (see the Inferencing section of the README). Afterwards, you can inspect the meshes using MeshLab or https://3dviewer.net/.
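As a rough illustration of what a color-coded export involves, the sketch below writes an ASCII PLY file with one RGB color per face, a format MeshLab can open. It assumes plain Python sequences for vertices, faces and labels; the palette and the name `write_face_colored_ply` are hypothetical, and the repository's own `color_mesh` helper (used with trimesh later in this thread) may pick different colors.

```python
import io

# Hypothetical palette: label index -> (r, g, b); labels beyond its length wrap around.
PALETTE = [
    (200, 200, 200), (230, 25, 75), (60, 180, 75), (255, 225, 25),
    (0, 130, 200), (245, 130, 48), (145, 30, 180), (70, 240, 240),
    (240, 50, 230),
]


def write_face_colored_ply(out, vertices, faces, labels, palette=PALETTE):
    """Write an ASCII PLY mesh whose faces are colored by an integer label.

    vertices: sequence of (x, y, z); faces: sequence of (i, j, k) vertex indices;
    labels: one integer label per face.
    """
    if len(faces) != len(labels):
        raise ValueError("need exactly one label per face")
    out.write("ply\nformat ascii 1.0\n")
    out.write(f"element vertex {len(vertices)}\n")
    out.write("property float x\nproperty float y\nproperty float z\n")
    out.write(f"element face {len(faces)}\n")
    out.write("property list uchar int vertex_indices\n")
    out.write("property uchar red\nproperty uchar green\nproperty uchar blue\n")
    out.write("end_header\n")
    for x, y, z in vertices:
        out.write(f"{x} {y} {z}\n")
    for (i, j, k), label in zip(faces, labels):
        r, g, b = palette[label % len(palette)]
        # Each face row: vertex count, three indices, then the face color
        out.write(f"3 {i} {j} {k} {r} {g} {b}\n")


# Usage: two triangles forming a square, with labels 0 and 1
buf = io.StringIO()
write_face_colored_ply(buf, [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],
                       [(0, 1, 2), (1, 3, 2)], [0, 1])
ply_text = buf.getvalue()
```

To write a real file, pass an object from `open("pred.ply", "w")` instead of the `StringIO` buffer.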

@ztbian-bzt
Author

@LucasKre Thanks a lot for your reply.

ztbian-bzt reopened this Jun 3, 2024
@ztbian-bzt
Author

    def get_model():
        return LitDilatedToothSegmentationNetwork()


    model = get_model()
    if use_gpu:
        model = model.cuda()
    model = model.load_from_checkpoint(ckpt_path)

This code leads to the error "TypeError: The classmethod LitDilatedToothSegmentationNetwork.load_from_checkpoint cannot be called on an instance. Please call it on the class type and make sure the return value is used." I changed it as follows, which avoids the error:

    model = LitDilatedToothSegmentationNetwork.load_from_checkpoint(ckpt_path)
    if use_gpu:
        model = model.cuda()
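The error arises because Lightning treats `load_from_checkpoint` as a class-level constructor that returns a new, fully loaded model rather than updating an existing one. The plain-Python toy below mimics that behavior to show why the call must be made on the class and its return value kept. It does not use Lightning; `restricted_classmethod` and `ToyModel` are hypothetical names, not Lightning internals.

```python
import types


class restricted_classmethod:
    """Toy descriptor: acts like @classmethod but refuses access via an instance."""

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        if instance is not None:
            raise TypeError(
                f"The classmethod {owner.__name__}.{self.func.__name__} "
                "cannot be called on an instance."
            )
        # Accessed on the class: bind the function to the class, like @classmethod
        return types.MethodType(self.func, owner)


class ToyModel:
    def __init__(self, weights=None):
        self.weights = weights

    @restricted_classmethod
    def load_from_checkpoint(cls, checkpoint):
        # Builds and returns a NEW model; it never modifies an existing instance
        return cls(weights=checkpoint["weights"])


checkpoint = {"weights": [0.1, 0.2]}
model = ToyModel.load_from_checkpoint(checkpoint)  # OK: called on the class

try:
    ToyModel().load_from_checkpoint(checkpoint)  # raises: called on an instance
except TypeError as err:
    error_message = str(err)
```

Even with a plain classmethod, `model.load_from_checkpoint(...)` would ignore the existing instance, so assigning the result is what actually gives you the trained weights.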

In addition, the visualization makes the teeth look elongated. How can I fix this?
(Screenshot: data_00OMSZGW_lower)

@ztbian-bzt
Author

I solved the above problem and thank you for your help.

@supgy

supgy commented Jun 13, 2024

> I solved the above problem and thank you for your help.

I have encountered the same problem as you. Can you teach me how to make it look regular? Thanks a lot.

@mykakus

mykakus commented Jul 12, 2024

> I have encountered the same problem as you. Can you teach me how to make it look regular? Thanks a lot.

You need to reverse the normalization steps in the PreTransform class (in preprocessing.py):

    for i in range(3):
        # normalize coordinate
        x[:, i] = (x[:, i] - means[i]) / stds[i]  # point 1
        x[:, i + 3] = (x[:, i + 3] - means[i]) / stds[i]  # point 2
        x[:, i + 6] = (x[:, i + 6] - means[i]) / stds[i]  # point 3
        x[:, i + 9] = (x[:, i + 9] - mins[i]) / (maxs[i] - mins[i])  # centre
        # normalize normal vector
        x[:, i + 12] = (x[:, i + 12] - nmeans[i]) / nstds[i]  # normal1
        x[:, i + 15] = (x[:, i + 15] - nmeans[i]) / nstds[i]  # normal2
        x[:, i + 18] = (x[:, i + 18] - nmeans[i]) / nstds[i]  # normal3
        x[:, i + 21] = (x[:, i + 21] - nmeans_f[i]) / nstds_f[i]  # face normal

Reverse (undo each step: multiply by the scale, then add the offset back):

    for i in range(3):
        # coordinate: x = x_norm * std + mean
        x[:, i] = x[:, i] * stds[i] + means[i]  # point 1
        x[:, i + 3] = x[:, i + 3] * stds[i] + means[i]  # point 2
        x[:, i + 6] = x[:, i + 6] * stds[i] + means[i]  # point 3
        # centre: c = c_norm * (max - min) + min
        x[:, i + 9] = x[:, i + 9] * (maxs[i] - mins[i]) + mins[i]  # centre
        # normal vector
        x[:, i + 12] = x[:, i + 12] * nstds[i] + nmeans[i]  # normal1
        x[:, i + 15] = x[:, i + 15] * nstds[i] + nmeans[i]  # normal2
        x[:, i + 18] = x[:, i + 18] * nstds[i] + nmeans[i]  # normal3
        x[:, i + 21] = x[:, i + 21] * nstds_f[i] + nmeans_f[i]  # face normal
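Written out per axis $i$, each forward step is an affine map, so its exact inverse multiplies by the scale and then adds the offset back. Here $\mu_i$, $\sigma_i$ stand for `means[i]`, `stds[i]` and $m_i$, $M_i$ for `mins[i]`, `maxs[i]`:

```latex
x' = \frac{x - \mu_i}{\sigma_i} \;\Longrightarrow\; x = \sigma_i\, x' + \mu_i,
\qquad
c' = \frac{c - m_i}{M_i - m_i} \;\Longrightarrow\; c = (M_i - m_i)\, c' + m_i
```

The normal-vector columns follow the first pattern, with `nmeans`/`nstds` for the vertex normals and `nmeans_f`/`nstds_f` for the face normal.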

Keep in mind that the original maxs, mins, means, etc. must be stored before normalization and then reused to reverse it.
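A small NumPy sketch makes both points concrete: why un-reversed coordinates look stretched, and why the stored statistics are needed to get the original shape back. It uses a hypothetical random point cloud rather than a real Teeth3DS mesh, and `normalize`/`denormalize`/`extent` are illustrative names, not functions from the repository.

```python
import numpy as np


def normalize(points):
    """Per-axis z-score; also returns the statistics needed to undo it."""
    means = points.mean(axis=0)
    stds = points.std(axis=0, ddof=1)  # ddof=1 matches torch's default (unbiased) std
    return (points - means) / stds, means, stds


def denormalize(normalized, means, stds):
    """Exact inverse of normalize(): scale back first, then add the mean back."""
    return normalized * stds + means


def extent(p):
    """Bounding-box size along each axis."""
    return p.max(axis=0) - p.min(axis=0)


# Hypothetical stand-in for a dental arch: wide in x, narrower in y, flat in z
rng = np.random.default_rng(0)
points = rng.uniform(size=(500, 3)) * np.array([60.0, 40.0, 10.0])

normalized, means, stds = normalize(points)
restored = denormalize(normalized, means, stds)

# Height-to-width ratio: about 10/60 originally, but close to 1 after normalization,
# because every axis is divided by its own standard deviation
ratio_original = extent(points)[2] / extent(points)[0]
ratio_normalized = extent(normalized)[2] / extent(normalized)[0]
```

Because each axis gets its own scale, a mesh built directly from normalized coordinates has distorted proportions; only the exact inverse with the same statistics (same mesh, same standard-deviation convention) restores them.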

@shanshanhuang2023

> I solved the above problem and thank you for your help.

Hello, I also ran into the visualization deformation problem, but the labels became jumbled after I applied the denormalization suggested above. Could you please tell me how you solved this problem?
(Screenshot: WeChat image 20240929161103)

@mykakus

mykakus commented Sep 30, 2024

I wrote a script for my personal use. It is pretty rough, but it works.
Place it in the same directory (dilated_tooth_seg_net) as all the other files.
In the script, specify the file path of the 3D object (Model_Teeth) and the path of your checkpoint with the model parameters (ML_parameters):

    #-----Define values
    Model_Teeth=r'\\.obj' # .obj file path in Teeth3DS dataset example: Teeth3DS\Upper\\0JN50XQR\\0JN50XQR_upper.obj
    ML_parameters='\\.ckpt' # model parameter file path

When the script completes, two files are created: pred.ply (predicted labels) and gt.ply (ground-truth labels).

import trimesh
import torch
import json
import pyfqmr
import numpy as np
# import polyscope as ps
from pathlib import Path
from models.dilated_tooth_seg_network import LitDilatedToothSegmentationNetwork
import random
from utils.teeth_numbering import color_mesh,colors_to_label,fdi_to_label
from lightning.pytorch import seed_everything
import copy
from scipy import spatial

# same function as in mesh_dataset
def process_mesh(mesh: trimesh.Trimesh, labels: torch.Tensor = None):
    mesh_faces = torch.from_numpy(mesh.faces.copy()).float()
    mesh_triangles = torch.from_numpy(mesh.vertices[mesh.faces]).float()
    mesh_face_normals = torch.from_numpy(mesh.face_normals.copy()).float()
    mesh_vertices_normals = torch.from_numpy(mesh.vertex_normals[mesh.faces]).float()
    if labels is None:
        labels = torch.from_numpy(colors_to_label(mesh.visual.face_colors.copy())).long()
    return mesh_faces, mesh_triangles, mesh_vertices_normals, mesh_face_normals, labels

# similar to PreTransform in preprocessing.py
def preporces(data):
    mesh_faces, mesh_triangles, mesh_vertices_normals, mesh_face_normals, labels = data
    mesh = trimesh.Trimesh(**trimesh.triangles.to_kwargs(mesh_triangles.cpu().detach().numpy()))

    points = torch.from_numpy(mesh.vertices)
    v_normals = torch.from_numpy(mesh.vertex_normals)

    s, _ = mesh_faces.size()
    x = torch.zeros(s, 24).float()
    x[:, :3] = mesh_triangles[:, 0]
    x[:, 3:6] = mesh_triangles[:, 1]
    x[:, 6:9] = mesh_triangles[:, 2]
    x[:, 9:12] = mesh_triangles.mean(dim=1)
    x[:, 12:15] = mesh_vertices_normals[:, 0]
    x[:, 15:18] = mesh_vertices_normals[:, 1]
    x[:, 18:21] = mesh_vertices_normals[:, 2]
    x[:, 21:] = mesh_face_normals

    maxs = points.max(dim=0)[0]
    mins = points.min(dim=0)[0]
    means = points.mean(axis=0)
    stds = points.std(axis=0)
    nmeans = v_normals.mean(axis=0)
    nstds = v_normals.std(axis=0)
    nmeans_f = mesh_face_normals.mean(axis=0)
    nstds_f = mesh_face_normals.std(axis=0)
    for i in range(3):
        # normalize coordinate
        x[:, i] = (x[:, i] - means[i]) / stds[i]  # point 1
        x[:, i + 3] = (x[:, i + 3] - means[i]) / stds[i]  # point 2
        x[:, i + 6] = (x[:, i + 6] - means[i]) / stds[i]  # point 3
        x[:, i + 9] = (x[:, i + 9] - mins[i]) / (maxs[i] - mins[i])  # centre
        # normalize normal vector
        x[:, i + 12] = (x[:, i + 12] - nmeans[i]) / nstds[i]  # normal1
        x[:, i + 15] = (x[:, i + 15] - nmeans[i]) / nstds[i]  # normal2
        x[:, i + 18] = (x[:, i + 18] - nmeans[i]) / nstds[i]  # normal3
        x[:, i + 21] = (x[:, i + 21] - nmeans_f[i]) / nstds_f[i]  # face normal

    pos = x[:, 9:12]

    return pos, x, labels

# same as the method in mesh_dataset.Teeth3DSDataset
def Downsample(mesh,labels):
    mesh_simplifier = pyfqmr.Simplify()
    mesh_simplifier.setMesh(mesh.vertices, mesh.faces)
    mesh_simplifier.simplify_mesh(target_count=16000, aggressiveness=3, preserve_border=True, verbose=0,
                                  max_iterations=2000)
    new_positions, new_face, _ = mesh_simplifier.getMesh()
    mesh_simple = trimesh.Trimesh(vertices=new_positions, faces=new_face)
    vertices = mesh_simple.vertices
    faces = mesh_simple.faces
    if faces.shape[0] < 16000:
        fs_diff = 16000 - faces.shape[0]
        faces = np.append(faces, np.zeros((fs_diff, 3), dtype="int"), 0)
    elif faces.shape[0] > 16000:
        mesh_simple = trimesh.Trimesh(vertices=vertices, faces=faces)
        samples, face_index = trimesh.sample.sample_surface_even(mesh_simple, 16000)
        mesh_simple = trimesh.Trimesh(vertices=mesh_simple.vertices, faces=mesh_simple.faces[face_index])
        faces = mesh_simple.faces
        vertices = mesh_simple.vertices
    mesh_simple = trimesh.Trimesh(vertices=vertices, faces=faces)

    mesh_v_mean = mesh.vertices[mesh.faces].mean(axis=1)
    mesh_simple_v = mesh_simple.vertices
    tree = spatial.KDTree(mesh_v_mean)
    query = mesh_simple_v[faces].mean(axis=1)
    distance, index = tree.query(query)
    labels = labels[index].flatten()
    return mesh_simple,labels

# reverse normalization: exact inverse of the steps in preporces
def PostProces(data_OG_def,x_def):
    _, mesh_triangles, _, mesh_face_normals, _ = data_OG_def
    mesh = trimesh.Trimesh(**trimesh.triangles.to_kwargs(mesh_triangles.cpu().detach().numpy()))

    # Recompute the statistics preporces used. torch's .std() is unbiased by default,
    # so pass ddof=1 to NumPy's .std() to get exactly the same values.
    maxs = mesh.vertices.max(axis=0)
    mins = mesh.vertices.min(axis=0)
    means = mesh.vertices.mean(axis=0)
    stds = mesh.vertices.std(axis=0, ddof=1)
    nmeans = mesh.vertex_normals.mean(axis=0)
    nstds = mesh.vertex_normals.std(axis=0, ddof=1)
    nmeans_f = mesh_face_normals.mean(axis=0)
    nstds_f = mesh_face_normals.std(axis=0)  # torch tensor, unbiased as in preporces
    for i in range(3):
        # coordinate: x = x_norm * std + mean
        x_def[:, i] = x_def[:, i] * stds[i] + means[i]  # point 1
        x_def[:, i + 3] = x_def[:, i + 3] * stds[i] + means[i]  # point 2
        x_def[:, i + 6] = x_def[:, i + 6] * stds[i] + means[i]  # point 3
        # centre: c = c_norm * (max - min) + min
        x_def[:, i + 9] = x_def[:, i + 9] * (maxs[i] - mins[i]) + mins[i]  # centre
        # normal vector
        x_def[:, i + 12] = x_def[:, i + 12] * nstds[i] + nmeans[i]  # normal1
        x_def[:, i + 15] = x_def[:, i + 15] * nstds[i] + nmeans[i]  # normal2
        x_def[:, i + 18] = x_def[:, i + 18] * nstds[i] + nmeans[i]  # normal3
        x_def[:, i + 21] = x_def[:, i + 21] * nstds_f[i] + nmeans_f[i]  # face normal
    return x_def

SEED = 42
use_gpu=True
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
torch.set_float32_matmul_precision('medium')
random.seed(SEED)
seed_everything(SEED, workers=True)

#-----Define values
Model_Teeth=r'\\.obj' # .obj file path in Teeth3DS dataset example: Teeth3DS\Upper\\0JN50XQR\\0JN50XQR_upper.obj
ML_parameters='\\.ckpt' # model parameter file path

#----------Model----------
model = LitDilatedToothSegmentationNetwork.load_from_checkpoint(ML_parameters)
if use_gpu:
    model = model.cuda()

#----Load mesh and ground-truth labels
mesh=trimesh.load(Path(Model_Teeth))
with open(Model_Teeth.replace('.obj', '.json')) as f:
    data = json.load(f)
labels = np.array(data["labels"])
labels = labels[mesh.faces]
labels = labels[:, 0]
labels = fdi_to_label(labels)

#----Downsample
mesh_simple,labels=Downsample(mesh,labels)

#----Preprocess
data = process_mesh(mesh_simple, torch.from_numpy(labels).long())
data_OG=copy.copy(data)
data =preporces(data)

#----Ground truth model labels
ground_truth = data[2]
mesh_gt = color_mesh(mesh_simple, ground_truth.numpy())
mesh_gt.export('gt.ply') # export ground truth 3D model

#----Use model
pre_labels = model.predict_labels(data).cpu().numpy()
x=PostProces(data_OG,data[1]) # Postprocess

triangles = x[:, :9].reshape(-1, 3, 3)
mesh = trimesh.Trimesh(**trimesh.triangles.to_kwargs(triangles.cpu().detach().numpy()))
mesh_pred = color_mesh(mesh, pre_labels)
mesh_pred.export('pred.ply') # export predicted 3D model

#----Show models with highlighted teeth. Requires polyscope (https://github.com/nmwsharp/polyscope) to be installed
# ps.init()
# color_groud=ps.register_surface_mesh('Original', mesh_simple.vertices-mesh_simple.centroid, mesh_simple.faces)
# color_groud.add_color_quantity("groud labels", mesh_gt.visual.face_colors[:,:3]/255, defined_on='faces')
# color_pred=ps.register_surface_mesh('Final model', mesh_pred.vertices-mesh_pred.centroid, mesh_pred.faces)
# color_pred.add_color_quantity("predicted labels", mesh_pred.visual.face_colors[:,:3]/255, defined_on='faces')
# ps.show()
# ps.remove_all_structures()

Screenshots (not reproduced here):

  • Predicted labels
  • Original labels
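The `Downsample` function in the script above carries the ground-truth labels over to the simplified mesh by matching each simplified face's centroid to the nearest face centroid of the original mesh (via `scipy.spatial.KDTree`) and copying that face's label. The same nearest-centroid transfer can be sketched with plain NumPy on a hypothetical toy mesh; the brute-force distance matrix is only for illustration (its memory grows with faces_simplified × faces_original), and `transfer_face_labels` is a hypothetical name.

```python
import numpy as np


def face_centroids(vertices, faces):
    """Mean of the three vertex positions of every triangle, shape (F, 3)."""
    return vertices[faces].mean(axis=1)


def transfer_face_labels(src_vertices, src_faces, src_labels, dst_vertices, dst_faces):
    """Give each destination face the label of the source face with the nearest centroid."""
    src_c = face_centroids(src_vertices, src_faces)  # (Fs, 3)
    dst_c = face_centroids(dst_vertices, dst_faces)  # (Fd, 3)
    # Pairwise squared distances, shape (Fd, Fs); fine for small meshes only
    d2 = ((dst_c[:, None, :] - src_c[None, :, :]) ** 2).sum(axis=-1)
    nearest = d2.argmin(axis=1)
    return src_labels[nearest]


# Source: a unit square split into 4 labeled triangles around its center
src_v = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0.5, 0.5, 0]], dtype=float)
src_f = np.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])
src_labels = np.array([11, 12, 13, 14])

# Destination: two of those triangles on a slightly shifted copy of the vertices
dst_v = src_v + 0.01
dst_f = np.array([[2, 3, 4], [1, 2, 4]])
dst_labels = transfer_face_labels(src_v, src_f, src_labels, dst_v, dst_f)
```

For full-resolution scans, a KD-tree query (as in the script) avoids building the dense distance matrix.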

@shanshanhuang2023

@mykakus Thank you very much for your reply, it helps me a lot!
