From 06f6b0dad8cc1866e205c58c1977fa3163565bb0 Mon Sep 17 00:00:00 2001 From: OliverGrace <5617905+OliverGrace@users.noreply.github.com> Date: Tue, 7 Feb 2023 03:48:44 +0800 Subject: [PATCH] fix test --- hat/archs/__init__.py | 2 +- hat/data/__init__.py | 2 +- hat/data/imagenet_paired_dataset.py | 3 ++- hat/models/__init__.py | 2 +- hat/test.py | 6 +++--- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/hat/archs/__init__.py b/hat/archs/__init__.py index e5af1ecf..42ec0694 100644 --- a/hat/archs/__init__.py +++ b/hat/archs/__init__.py @@ -8,4 +8,4 @@ arch_folder = osp.dirname(osp.abspath(__file__)) arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] # import all the arch modules -_arch_modules = [importlib.import_module(f'hat.archs.{file_name}') for file_name in arch_filenames] +_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] diff --git a/hat/data/__init__.py b/hat/data/__init__.py index f4819d53..b22fe2bb 100644 --- a/hat/data/__init__.py +++ b/hat/data/__init__.py @@ -8,4 +8,4 @@ data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] # import all the dataset modules -_dataset_modules = [importlib.import_module(f'hat.data.{file_name}') for file_name in dataset_filenames] +_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] diff --git a/hat/data/imagenet_paired_dataset.py b/hat/data/imagenet_paired_dataset.py index 91407377..e63319c1 100644 --- a/hat/data/imagenet_paired_dataset.py +++ b/hat/data/imagenet_paired_dataset.py @@ -7,7 +7,8 @@ from basicsr.data.data_util import paths_from_lmdb, scandir from basicsr.data.transforms import augment, paired_random_crop from basicsr.utils import FileClient, imfrombytes, img2tensor -from basicsr.utils.matlab_functions import imresize, rgb2ycbcr +from basicsr.utils.matlab_functions import imresize +from basicsr.utils.color_util import rgb2ycbcr from basicsr.utils.registry import DATASET_REGISTRY diff --git a/hat/models/__init__.py b/hat/models/__init__.py index fc0917a2..a92d6ace 100644 --- a/hat/models/__init__.py +++ b/hat/models/__init__.py @@ -8,4 +8,4 @@ model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] # import all the model modules -_model_modules = [importlib.import_module(f'hat.models.{file_name}') for file_name in model_filenames] +_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] diff --git a/hat/test.py b/hat/test.py index b334817a..ac070ffb 100644 --- a/hat/test.py +++ b/hat/test.py @@ -1,9 +1,9 @@ # flake8: noqa import os.path as osp -import hat.archs -import hat.data -import hat.models +import archs +import data +import models from basicsr.test import test_pipeline if __name__ == '__main__':