Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add new real-esrgan model real-esrgan-general-x4v3 #1566

Merged
merged 2 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions real-esrgan/general-x4v3/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.16)
project(real-esrgan)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/")

add_definitions(-std=c++17)
add_definitions(-DAPI_EXPORTS)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
#set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

#find_package(CUDA REQUIRED)

INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/src/include)

# cuda
FIND_PACKAGE(CUDA REQUIRED)
#INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)

# <------------------------TensorRT Related------------------------->
include_directories(YOUR_TENSORRT_INCLUDE_DIR) # TensorRT-8.6.1.6/include
link_directories(YOUR_TENSORRT_LIB_DIR) # TensorRT-8.6.1.6/lib

# <------------------------OpenCV Related------------------------->
# opencv
FIND_PACKAGE(OpenCV REQUIRED)
INCLUDE_DIRECTORIES(${OpenCV_INCLUDE_DIRS})

set(CMAKE_CXX_STANDARD 17)

add_executable(${PROJECT_NAME} main.cpp)

cuda_add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/src/pixel_shuffle/pixel_shuffle.cu)
target_link_libraries(myplugins nvinfer cudart)


TARGET_LINK_LIBRARIES(${PROJECT_NAME} nvinfer)
TARGET_LINK_LIBRARIES(${PROJECT_NAME} cudart)
TARGET_LINK_LIBRARIES(${PROJECT_NAME} ${OpenCV_LIBS})
TARGET_LINK_LIBRARIES(${PROJECT_NAME} myplugins)
41 changes: 41 additions & 0 deletions real-esrgan/general-x4v3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Real-ESRGAN realesr-general-x4v3 model

## How to Run
0. Replace YOUR_TENSORRT_INCLUDE_DIR and YOUR_TENSORRT_LIB_DIR in CMakeLists.txt with your TensorRT include and lib directories.
1. generate .wts from pytorch with .pt
```
git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN

# Install basicsr - https://github.com/xinntao/BasicSR
# We use BasicSR for both training and inference
pip install basicsr
# facexlib and gfpgan are for face enhancement
pip install facexlib
pip install gfpgan
pip install -r requirements.txt
python setup.py develop
```
download realesr-general-x4v3.pth (and realesr-general-wdn-x4v3.pth if needed) from
https://github.com/xinntao/Real-ESRGAN/releases

```
cp {tensorrtx}/real-esrgan-general-x4v3/gen_wts.py {xinntao}/Real-ESRGAN
cd {xinntao}/Real-ESRGAN
python gen_wts.py
// a file 'real-esrgan.wts' will be generated.
```

**Be aware that if you need both realesr-general-x4v3.pth and realesr-general-wdn-x4v3.pth, please write a Python script to average all weights of realesr-general-x4v3.pth and realesr-general-wdn-x4v3.pth (from {xinntao}/Real-ESRGAN), then save it as a .pth file, and use this new file to generate a .wts file.**

2. build tensorrtx/real-esrgan-general-x4v3 and run

```
cd {tensorrtx}/real-esrgan-general-x4v3/
mkdir build
cd build
cp {xinntao}/Real-ESRGAN/real-esrgan.wts {tensorrtx}/real-esrgan/weights/
cmake ..
make
./real-esrgan your_images_dir
```
87 changes: 87 additions & 0 deletions real-esrgan/general-x4v3/cmake/FindTensorRT.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# source:
# https://github.com/NVIDIA/tensorrt-laboratory/blob/master/cmake/FindTensorRT.cmake

# This module defines the following variables:
#
# ::
#
# TensorRT_INCLUDE_DIRS
# TensorRT_LIBRARIES
# TensorRT_FOUND
#
# ::
#
# TensorRT_VERSION_STRING - version (x.y.z)
# TensorRT_VERSION_MAJOR - major version (x)
# TensorRT_VERSION_MINOR - minor version (y)
# TensorRT_VERSION_PATCH - patch version (z)
#
# Hints
# ^^^^^
# A user may set ``TensorRT_DIR`` to an installation root to tell this module where to look.
#
set(_TensorRT_SEARCHES)

if(TensorRT_DIR)
set(_TensorRT_SEARCH_ROOT PATHS ${TensorRT_DIR} NO_DEFAULT_PATH)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_ROOT)
endif()

# appends some common paths
set(_TensorRT_SEARCH_NORMAL
PATHS "/usr"
)
list(APPEND _TensorRT_SEARCHES _TensorRT_SEARCH_NORMAL)

# Include dir
foreach(search ${_TensorRT_SEARCHES})
find_path(TensorRT_INCLUDE_DIR NAMES NvInfer.h ${${search}} PATH_SUFFIXES include)
endforeach()

if(NOT TensorRT_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_LIBRARY NAMES nvinfer ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()

if(NOT TensorRT_PARSERS_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_NVPARSERS_LIBRARY NAMES nvparsers ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()

if(NOT TensorRT_NVONNXPARSER_LIBRARY)
foreach(search ${_TensorRT_SEARCHES})
find_library(TensorRT_NVONNXPARSER_LIBRARY NAMES nvonnxparser ${${search}} PATH_SUFFIXES lib)
endforeach()
endif()

mark_as_advanced(TensorRT_INCLUDE_DIR)

if(TensorRT_INCLUDE_DIR AND EXISTS "${TensorRT_INCLUDE_DIR}/NvInfer.h")
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$")
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$")
file(STRINGS "${TensorRT_INCLUDE_DIR}/NvInfer.h" TensorRT_PATCH REGEX "^#define NV_TENSORRT_PATCH [0-9]+.*$")

string(REGEX REPLACE "^#define NV_TENSORRT_MAJOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MAJOR "${TensorRT_MAJOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_MINOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MINOR "${TensorRT_MINOR}")
string(REGEX REPLACE "^#define NV_TENSORRT_PATCH ([0-9]+).*$" "\\1" TensorRT_VERSION_PATCH "${TensorRT_PATCH}")
set(TensorRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${TensorRT_VERSION_PATCH}")
endif()

include(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(TensorRT REQUIRED_VARS TensorRT_LIBRARY TensorRT_INCLUDE_DIR VERSION_VAR TensorRT_VERSION_STRING)

if(TensorRT_FOUND)
set(TensorRT_INCLUDE_DIRS ${TensorRT_INCLUDE_DIR})

if(NOT TensorRT_LIBRARIES)
set(TensorRT_LIBRARIES ${TensorRT_LIBRARY} ${TensorRT_NVONNXPARSER_LIBRARY} ${TensorRT_NVPARSERS_LIBRARY})
endif()

if(NOT TARGET TensorRT::TensorRT)
add_library(TensorRT::TensorRT UNKNOWN IMPORTED)
set_target_properties(TensorRT::TensorRT PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${TensorRT_INCLUDE_DIRS}")
set_property(TARGET TensorRT::TensorRT APPEND PROPERTY IMPORTED_LOCATION "${TensorRT_LIBRARY}")
endif()
endif()
136 changes: 136 additions & 0 deletions real-esrgan/general-x4v3/gen_wts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import argparse
import os
import struct
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, help='Input image or folder')
parser.add_argument(
'-n',
'--model_name',
type=str,
default='realesr-general-x4v3',
help=('RealESRGAN_x2plus Model names: '
'realesr-animevideov3 | realesr-general-x4v3'))
parser.add_argument('-o', '--output', type=str, help='Output folder')
parser.add_argument(
'-dn',
'--denoise_strength',
type=float,
default=0.5,
help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
'Only used for the realesr-general-x4v3 model'))
parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
parser.add_argument(
'--model_path', type=str, default=None, help='[Option] Model path. Usually, you do not need to specify it')
parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
parser.add_argument(
'--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
parser.add_argument(
'--alpha_upsampler',
type=str,
default='realesrgan',
help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
parser.add_argument(
'--ext',
type=str,
default='auto',
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
parser.add_argument(
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')

args = parser.parse_args()

# determine models according to model names
args.model_name = args.model_name.split('.')[0]
if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
elif args.model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
netscale = 4
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
netscale = 4
file_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
]

# determine model paths
if args.model_path is not None:
model_path = args.model_path
else:
model_path = os.path.join('weights', args.model_name + '.pth')
if not os.path.isfile(model_path):
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
for url in file_url:
# model_path will be updated
model_path = load_file_from_url(
url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

# use dni to control the denoise strength
dni_weight = None
if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
# wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
# model_path = [model_path, wdn_model_path]
# dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-x4v3-cat')
dni_weight = None

# restorer
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32,
gpu_id=args.gpu_id)

if os.path.isfile('real-esrgan.wts'):
print('Already, real-esrgan.wts file exists.')
else:
print('making real-esrgan.wts file ...')
f = open("real-esrgan.wts", 'w')
f.write("{}\n".format(len(upsampler.model.state_dict().keys())))
for k, v in upsampler.model.state_dict().items():
print('key: ', k)
print('value: ', v.shape)
vr = v.reshape(-1).cpu().numpy()
f.write("{} {}".format(k, len(vr)))
for vv in vr:
f.write(" ")
f.write(struct.pack(">f", float(vv)).hex())
f.write("\n")
print('Completed real-esrgan.wts file!')


if __name__ == '__main__':
main()
Loading
Loading