diff --git a/README.md b/README.md
index 33d66af..acd072c 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ Our work will be presented at the *IROS 2022* conference in Kyoto, Japan.
- Comparison between WaSR (single-frame) and WaSR-T (temporal context).
+ Comparison between WaSR (single-frame) and WaSR-T (temporal context) on hard examples.
diff --git a/train.py b/train.py
index ebc0d58..d323132 100644
--- a/train.py
+++ b/train.py
@@ -85,8 +85,6 @@ def get_arguments(input_args=None):
                         help="Name of the model. Used to create model and log directories inside the output directory.")
     parser.add_argument("--pretrained-weights", type=str, default=None,
                         help="Path to the pretrained weights to be used.")
-    parser.add_argument("--architecture", type=str, choices=M.models, default=ARCHITECTURE,
-                        help="Which architecture to use.")
     parser.add_argument("--monitor-metric", type=str, default=MONITOR_VAR,
                         help="Validation metric to monitor for early stopping and best model saving.")
     parser.add_argument("--monitor-metric-mode", type=str, default=MONITOR_VAR_MODE, choices=['min', 'max'],
diff --git a/wasr_t/train.py b/wasr_t/train.py
index 1cf4ca9..8163885 100644
--- a/wasr_t/train.py
+++ b/wasr_t/train.py
@@ -8,7 +8,7 @@
 from .loss import focal_loss, water_obstacle_separation_loss
 from .metrics import PixelAccuracy, ClassIoU
 
-NUM_EPOCHS = 50
+NUM_EPOCHS = 100
 LEARNING_RATE = 1e-6
 MOMENTUM = 0.9
 WEIGHT_DECAY = 1e-6
diff --git a/wasr_t/wasr_t.py b/wasr_t/wasr_t.py
index 19fc239..ada30a9 100644
--- a/wasr_t/wasr_t.py
+++ b/wasr_t/wasr_t.py
@@ -44,11 +44,9 @@ def wasr_temporal_resnet101(num_classes=3, pretrained=True, sequential=False):
 class WaSRT(nn.Module):
     """WaSR-T model"""
 
-    def __init__(self, backbone, decoder, backbone_grad_steps=3, imu=False, sequential=False):
+    def __init__(self, backbone, decoder, backbone_grad_steps=3, sequential=False):
         super(WaSRT, self).__init__()
 
-        self.imu = imu
-
         self.backbone = backbone
         self.decoder = decoder
         self.backbone_grad_steps = backbone_grad_steps
@@ -92,9 +90,6 @@ def forward_unrolled(self, x):
         for f in extract_feats:
             feats_hist[f] = torch.stack(feats_hist[f], 1)
 
-        if self.imu:
-            features['imu_mask'] = x['imu_mask']
-
         x = self.decoder(features, feats_hist)
 
         # Return segmentation map and aux feature map
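For context, the `wasr_t/wasr_t.py` hunks above remove IMU support from `WaSRT`: the constructor drops the `imu` flag and `forward_unrolled` no longer injects an `imu_mask` into the decoder features. A minimal construction sketch of the slimmed-down model, assuming the repo root is on the import path (`pretrained=False` is chosen here only so the sketch runs without downloading backbone weights; it is not implied by this diff):

```python
from wasr_t.wasr_t import wasr_temporal_resnet101

# Build the temporal model exposed by this module. After this change,
# WaSRT.__init__ takes only (backbone, decoder, backbone_grad_steps,
# sequential) -- passing the removed `imu=True` keyword would raise a
# TypeError.
model = wasr_temporal_resnet101(num_classes=3, pretrained=False, sequential=False)
model.eval()
```

Together with the removal of the `--architecture` flag from `train.py`, this appears to commit the training script to the single WaSR-T architecture rather than selecting among IMU and non-IMU variants.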