Skip to content

Commit

Permalink
Training updated
Browse files Browse the repository at this point in the history
  • Loading branch information
Lojze Žust committed Jul 23, 2022
1 parent fea97f6 commit f3ef327
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Our work will be presented at the *IROS 2022* conference in Kyoto, Japan.

<p align="center">
<img src="figures/comparison.gif" alt="Comparison WaSR - WaSR-T">
Comparison between WaSR (single-frame) and WaSR-T (temporal context).
Comparison between WaSR (single-frame) and WaSR-T (temporal context) on hard examples.
</p>


Expand Down
2 changes: 0 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def get_arguments(input_args=None):
help="Name of the model. Used to create model and log directories inside the output directory.")
parser.add_argument("--pretrained-weights", type=str, default=None,
help="Path to the pretrained weights to be used.")
parser.add_argument("--architecture", type=str, choices=M.models, default=ARCHITECTURE,
help="Which architecture to use.")
parser.add_argument("--monitor-metric", type=str, default=MONITOR_VAR,
help="Validation metric to monitor for early stopping and best model saving.")
parser.add_argument("--monitor-metric-mode", type=str, default=MONITOR_VAR_MODE, choices=['min', 'max'],
Expand Down
2 changes: 1 addition & 1 deletion wasr_t/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .loss import focal_loss, water_obstacle_separation_loss
from .metrics import PixelAccuracy, ClassIoU

NUM_EPOCHS = 50
NUM_EPOCHS = 100
LEARNING_RATE = 1e-6
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
Expand Down
7 changes: 1 addition & 6 deletions wasr_t/wasr_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ def wasr_temporal_resnet101(num_classes=3, pretrained=True, sequential=False):

class WaSRT(nn.Module):
"""WaSR-T model"""
def __init__(self, backbone, decoder, backbone_grad_steps=3, imu=False, sequential=False):
def __init__(self, backbone, decoder, backbone_grad_steps=3, sequential=False):
super(WaSRT, self).__init__()

self.imu = imu

self.backbone = backbone
self.decoder = decoder
self.backbone_grad_steps = backbone_grad_steps
Expand Down Expand Up @@ -92,9 +90,6 @@ def forward_unrolled(self, x):
for f in extract_feats:
feats_hist[f] = torch.stack(feats_hist[f], 1)

if self.imu:
features['imu_mask'] = x['imu_mask']

x = self.decoder(features, feats_hist)

# Return segmentation map and aux feature map
Expand Down

0 comments on commit f3ef327

Please sign in to comment.