Skip to content

Commit

Permalink
ENH: Add model path by default in the docker container
Browse files Browse the repository at this point in the history
The trained model is download and extracted automatically when building the docker container
  • Loading branch information
juanprietob committed Apr 12, 2021
1 parent 6804b86 commit 90c3290
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
5 changes: 5 additions & 0 deletions Docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ RUN unzip vtk8.2.0.zip
RUN pip install itk sklearn pandas matplotlib

WORKDIR /app

RUN git clone https://github.com/DCBIA-OrthoLab/fly-by-cnn.git

RUN wget https://github.com/DCBIA-OrthoLab/fly-by-cnn/releases/download/2.1/u_seg_nn_v3.0.zip
RUN unzip u_seg_nn_v3.0.zip


ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.6/dist-packages/vtkmodules/
ENV MESA_GL_VERSION_OVERRIDE=3.2
9 changes: 6 additions & 3 deletions src/py/predict_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

parser = argparse.ArgumentParser(description='Predict an input with a trained neural network', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--surf', type=str, help='Input surface mesh to label', required=True)
parser.add_argument('--model', type=str, help='Model to do segmentation', required=True)
parser.add_argument('--model', type=str, help='Model to do segmentation', default="/app/u_seg_nn_v3.0")
parser.add_argument('--out', type=str, help='Output model with labels', default="out.vtk")

args = parser.parse_args()
Expand Down Expand Up @@ -42,8 +42,11 @@
flyby_features.removeActors()


model = tf.keras.models.load_model(args.model, custom_objects={'tf': tf})
model.summary()
if os.path.exist(args.model):
model = tf.keras.models.load_model(args.model, custom_objects={'tf': tf})
model.summary()
else:
print("Please set the model directory to a valid path", file=sys.stderr)

print("Predict ...")
img_predict_np = model.predict(img_np)
Expand Down

0 comments on commit 90c3290

Please sign in to comment.