diff --git a/intrinsic_compositing/shading/pipeline.py b/intrinsic_compositing/shading/pipeline.py index 0c9d412..56345f6 100644 --- a/intrinsic_compositing/shading/pipeline.py +++ b/intrinsic_compositing/shading/pipeline.py @@ -13,7 +13,7 @@ def load_reshading_model(path, device='cuda'): if path == 'paper_weights': state_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/shading_paper_weights.pt', map_location=device, progress=True) - if path == 'further_trained' + if path == 'further_trained': state_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/IntrinsicCompositing/releases/download/1.0.0/further_trained.pt', map_location=device, progress=True) else: state_dict = torch.load(path)