Skip to content

Commit

Permalink
added test epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
rostyslavhereha committed Feb 1, 2024
1 parent dc7c991 commit f0a6234
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions auto_training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

from auto_training.config_factories.pvt_config_factory import make_pvt_cfg
from mmdet import __version__
from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.apis import init_random_seed, set_random_seed, train_detector, single_gpu_test
from mmdet.datasets import build_dataset, build_dataloader
from mmdet.models import build_detector
from mmdet.utils import (collect_env, get_device, get_root_logger,
replace_cfg_vals, setup_multi_processes,
update_data_root)
update_data_root, build_dp)

def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
Expand Down Expand Up @@ -252,6 +252,19 @@ def main():
timestamp=timestamp,
meta=meta)

test_dataloader_default_args = dict(
samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)

dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(dataset, **test_dataloader_default_args)
model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
outputs = single_gpu_test(model, data_loader)
metric = dataset.evaluate(outputs)
logger.info(
f"Epoch(test) [{args.max_epochs}][{args.max_epochs}] "
f"\t {', '.join(f'{key}: {value}' for key, value in metric.items())}"
)


if __name__ == '__main__':
main()

0 comments on commit f0a6234

Please sign in to comment.