Skip to content
This repository was archived by the owner on Jan 22, 2025. It is now read-only.

Commit 1496971

Browse files
Yanghan Wang authored and facebook-github-bot committed
use "legacy" dataclass at operator level and separate TestNetOutput from TrainNetOutput
Summary: Pull Request resolved: #444
Differential Revision: D41828774
fbshipit-source-id: 833dea0e79eccdb8396bafb6c73f5255fa3cfddb
1 parent 02723f2 commit 1496971

File tree

3 files changed: 27 additions and 19 deletions.

d2go/trainer/api.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,20 @@
1111

1212
from d2go.evaluation.api import AccuracyDict, MetricsDict
1313

14-
# TODO (T127368935) Split to TrainNetOutput and TestNetOutput
14+
1515
@dataclass
1616
class TrainNetOutput:
1717
accuracy: AccuracyDict[float]
1818
metrics: MetricsDict[float]
19-
# Optional, because we use None to distinguish "not used" from
20-
# empty model configs. With T127368935, this should be reverted to dict.
21-
model_configs: Optional[Dict[str, str]]
19+
model_configs: Dict[str, str]
20+
# TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
21+
tensorboard_log_dir: Optional[str] = None
22+
23+
24+
@dataclass
25+
class TestNetOutput:
26+
accuracy: AccuracyDict[float]
27+
metrics: MetricsDict[float]
2228
# TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
2329
tensorboard_log_dir: Optional[str] = None
2430

tools/lightning_train_net.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
1313
from d2go.runner.lightning_task import DefaultTask
1414
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
15-
from d2go.trainer.api import TrainNetOutput
15+
from d2go.trainer.api import TestNetOutput, TrainNetOutput
1616
from d2go.trainer.helper import parse_precision_from_string
1717
from d2go.trainer.lightning.training_loop import _do_test, _do_train
1818
from detectron2.utils.file_io import PathManager
@@ -103,7 +103,7 @@ def main(
103103
output_dir: str,
104104
runner_class: Union[str, Type[DefaultTask]],
105105
eval_only: bool = False,
106-
) -> TrainNetOutput:
106+
) -> Union[TrainNetOutput, TestNetOutput]:
107107
"""Main function for launching a training with lightning trainer
108108
Args:
109109
cfg: D2go config node
@@ -123,18 +123,22 @@ def main(
123123
logger.info(f"Resuming training from checkpoint: {last_checkpoint}.")
124124

125125
trainer = pl.Trainer(**trainer_params)
126-
model_configs = None
126+
127127
if eval_only:
128128
_do_test(trainer, task)
129+
return TestNetOutput(
130+
tensorboard_log_dir=trainer_params["logger"].log_dir,
131+
accuracy=task.eval_res,
132+
metrics=task.eval_res,
133+
)
129134
else:
130135
model_configs = _do_train(cfg, trainer, task)
131-
132-
return TrainNetOutput(
133-
tensorboard_log_dir=trainer_params["logger"].log_dir,
134-
accuracy=task.eval_res,
135-
metrics=task.eval_res,
136-
model_configs=model_configs,
137-
)
136+
return TrainNetOutput(
137+
tensorboard_log_dir=trainer_params["logger"].log_dir,
138+
accuracy=task.eval_res,
139+
metrics=task.eval_res,
140+
model_configs=model_configs,
141+
)
138142

139143

140144
def argument_parser():

tools/train_net.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import sys
1010
from typing import List, Type, Union
1111

12-
import detectron2.utils.comm as comm
1312
from d2go.config import CfgNode
1413
from d2go.distributed import launch
1514
from d2go.runner import BaseRunner
@@ -22,7 +21,7 @@
2221
setup_before_launch,
2322
setup_root_logger,
2423
)
25-
from d2go.trainer.api import TrainNetOutput
24+
from d2go.trainer.api import TestNetOutput, TrainNetOutput
2625
from d2go.trainer.fsdp import create_ddp_model_with_sharding
2726
from d2go.utils.misc import (
2827
dump_trained_model_configs,
@@ -40,7 +39,7 @@ def main(
4039
runner_class: Union[str, Type[BaseRunner]],
4140
eval_only: bool = False,
4241
resume: bool = True, # NOTE: always enable resume when running on cluster
43-
) -> TrainNetOutput:
42+
) -> Union[TrainNetOutput, TestNetOutput]:
4443
runner = setup_after_launch(cfg, output_dir, runner_class)
4544

4645
model = runner.build_model(cfg)
@@ -58,9 +57,8 @@ def main(
5857
model.eval()
5958
metrics = runner.do_test(cfg, model, train_iter=train_iter)
6059
print_metrics_table(metrics)
61-
return TrainNetOutput(
60+
return TestNetOutput(
6261
accuracy=metrics,
63-
model_configs={},
6462
metrics=metrics,
6563
)
6664

0 commit comments

Comments
 (0)