diff --git a/dnnlib/submission/submit.py b/dnnlib/submission/submit.py index 514647d..d2ce8a5 100644 --- a/dnnlib/submission/submit.py +++ b/dnnlib/submission/submit.py @@ -129,10 +129,9 @@ def get_path_from_template(path_template: str, path_type: PathType = PathType.AU # return correctly formatted path if path_type == PathType.WINDOWS: return str(pathlib.PureWindowsPath(path_template)) - elif path_type == PathType.LINUX: + if path_type == PathType.LINUX: return str(pathlib.PurePosixPath(path_template)) - else: - raise RuntimeError("Unknown platform") + raise RuntimeError("Unknown platform") def get_template_from_path(path: str) -> str: @@ -158,9 +157,9 @@ def get_user_name(): """Get the current user name.""" if _user_name_override is not None: return _user_name_override - elif platform.system() == "Windows": + if platform.system() == "Windows": return os.getlogin() - elif platform.system() == "Linux": + if platform.system() == "Linux": try: import pwd return pwd.getpwuid(os.geteuid()).pw_name @@ -283,15 +282,14 @@ def run_wrapper(submit_config: SubmitConfig) -> None: except: if is_local: raise - else: - traceback.print_exc() + traceback.print_exc() - log_src = os.path.join(submit_config.run_dir, "log.txt") - log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) - shutil.copyfile(log_src, log_dst) + log_src = os.path.join(submit_config.run_dir, "log.txt") + log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) + shutil.copyfile(log_src, log_dst) - # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt - exit_with_errcode = True + # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt + exit_with_errcode = True finally: open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() diff --git a/dnnlib/tflib/custom_ops.py b/dnnlib/tflib/custom_ops.py index f87c0d8..53dc292 100644 --- a/dnnlib/tflib/custom_ops.py +++ b/dnnlib/tflib/custom_ops.py @@ -84,7 +84,7 @@ def _prepare_nvcc_cli(opts): #---------------------------------------------------------------------------- # Main entry point. -_plugin_cache = dict() +_plugin_cache = {} def get_plugin(cuda_file): cuda_file_base = os.path.basename(cuda_file) diff --git a/dnnlib/tflib/network.py b/dnnlib/tflib/network.py index 409babb..045cca0 100644 --- a/dnnlib/tflib/network.py +++ b/dnnlib/tflib/network.py @@ -23,7 +23,7 @@ from .tfutil import TfExpression, TfExpressionEx _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. -_import_module_src = dict() # Source code for temporary modules created during pickle import. +_import_module_src = {} # Source code for temporary modules created during pickle import. def import_handler(handler_func): @@ -120,7 +120,7 @@ def _init_fields(self) -> None: self._build_func = None # User-supplied build function that constructs the network. self._build_func_name = None # Name of the build function. self._build_module_src = None # Full source code of the module containing the build function. - self._run_cache = dict() # Cached graph data for Network.run(). + self._run_cache = {} # Cached graph data for Network.run(). def _init_graph(self) -> None: # Collect inputs. @@ -254,7 +254,7 @@ def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[ def __getstate__(self) -> dict: """Pickle export.""" - state = dict() + state = {} state["version"] = 4 state["name"] = self.name state["static_kwargs"] = dict(self.static_kwargs) diff --git a/dnnlib/tflib/tfutil.py b/dnnlib/tflib/tfutil.py index 1127c7b..2b2a3c9 100644 --- a/dnnlib/tflib/tfutil.py +++ b/dnnlib/tflib/tfutil.py @@ -83,7 +83,7 @@ def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: def _sanitize_tf_config(config_dict: dict = None) -> dict: # Defaults. - cfg = dict() + cfg = {} cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. @@ -227,20 +227,24 @@ def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwar return var -def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): +def convert_images_from_uint8(images, drange=None, nhwc_to_nchw=False): """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. Can be used as an input transformation for Network.run(). """ + if drange is None: + drange = [-1,1] images = tf.cast(images, tf.float32) if nhwc_to_nchw: images = tf.transpose(images, [0, 3, 1, 2]) return images * ((drange[1] - drange[0]) / 255) + drange[0] -def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): +def convert_images_to_uint8(images, drange=None, nchw_to_nhwc=False, shrink=1): """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. Can be used as an output transformation for Network.run(). """ + if drange is None: + drange = [-1,1] images = tf.cast(images, tf.float32) if shrink > 1: ksize = [1, 1, shrink, shrink] diff --git a/module/cnf.py b/module/cnf.py index 6237151..265f37b 100644 --- a/module/cnf.py +++ b/module/cnf.py @@ -25,10 +25,9 @@ def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_ # print(x.shape) x = self.chain[i](x, context, logpx, integration_times, reverse) return x - else: - for i in inds: - x, logpx = self.chain[i](x, context, logpx, integration_times, reverse) - return x, logpx + for i in inds: + x, logpx = self.chain[i](x, context, logpx, integration_times, reverse) + return x, logpx class CNF(nn.Module): @@ -115,8 +114,7 @@ def forward(self, x, context=None, logpx=None, integration_times=None, reverse=F if logpx is not None: return z_t, logpz_t - else: - return z_t + return z_t def num_evals(self): return self.odefunc._num_evals.item() diff --git a/module/normalization.py b/module/normalization.py index abde0d8..5f93257 100644 --- a/module/normalization.py +++ b/module/normalization.py @@ -40,8 +40,7 @@ def reset_parameters(self): def forward(self, x, c=None, logpx=None, reverse=False): if reverse: return self._reverse(x, logpx) - else: - return self._forward(x, logpx) + return self._forward(x, logpx) def _forward(self, x, logpx=None): num_channels = x.size(-1) @@ -87,8 +86,7 @@ def _forward(self, x, logpx=None): if logpx is None: return y - else: - return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True) + return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True) def _reverse(self, y, logpy=None): used_mean = self.running_mean @@ -105,8 +103,7 @@ def _reverse(self, y, logpy=None): if logpy is None: return x - else: - return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True) + return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True) def _logdetgrad(self, x, used_var): logdetgrad = -0.5 * torch.log(used_var + self.eps) diff --git a/module/odefunc.py b/module/odefunc.py index 7ded858..351c325 100644 --- a/module/odefunc.py +++ b/module/odefunc.py @@ -134,9 +134,8 @@ def forward(self, t, states): divergence = self.divergence_fn(dy, y, e=self._e).unsqueeze(-1) return dy, -divergence, torch.zeros_like(c).requires_grad_(True) - elif len(states) == 2: # unconditional CNF + if len(states) == 2: # unconditional CNF dy = self.diffeq(t, y) divergence = self.divergence_fn(dy, y, e=self._e).view(-1, 1) return dy, -divergence - else: - assert 0, "`len(states)` should be 2 or 3" \ No newline at end of file + assert 0, "`len(states)` should be 2 or 3" \ No newline at end of file diff --git a/module/utils.py b/module/utils.py index f35176c..aa46b23 100644 --- a/module/utils.py +++ b/module/utils.py @@ -106,7 +106,9 @@ def set_random_seed(seed): # Visualization -def visualize_point_clouds(pts, gtr, idx, pert_order=[0, 1, 2]): +def visualize_point_clouds(pts, gtr, idx, pert_order=None): + if pert_order is None: + pert_order = [0, 1, 2] pts = pts.cpu().detach().numpy()[:, pert_order] gtr = gtr.cpu().detach().numpy()[:, pert_order] diff --git a/pretrained_networks.py b/pretrained_networks.py index 40ccfd9..5297466 100644 --- a/pretrained_networks.py +++ b/pretrained_networks.py @@ -59,7 +59,7 @@ def get_path_or_url(path_or_gdrive_path): #---------------------------------------------------------------------------- -_cached_networks = dict() +_cached_networks = {} def load_networks(path_or_gdrive_path): path_or_url = get_path_or_url(path_or_gdrive_path) diff --git a/run_generator.py b/run_generator.py index 339796c..fb7e7c9 100644 --- a/run_generator.py +++ b/run_generator.py @@ -52,7 +52,7 @@ def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_ all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component] all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component] all_w = w_avg + (all_w - w_avg) * truncation_psi # [minibatch, layer, component] - w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} # [layer, component] + w_dict = dict(zip(all_seeds, list(all_w))) # [layer, component] print('Generating images...') all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs) # [minibatch, height, width, channel] diff --git a/utils.py b/utils.py index 63b91ea..ecaef2d 100644 --- a/utils.py +++ b/utils.py @@ -10,7 +10,6 @@ import numpy as np import PIL.Image -import dnnlib import dnnlib.tflib as tflib import os import re @@ -142,7 +141,9 @@ def get_style_loss(base_style, gram_target): #---------------------------------------------------------------------------- -def generate_im_official(network_pkl='gdrive:networks/stylegan2-ffhq-config-f.pkl', seeds=[22], truncation_psi=0.5): +def generate_im_official(network_pkl='gdrive:networks/stylegan2-ffhq-config-f.pkl', seeds=None, truncation_psi=0.5): + if seeds is None: + seeds = [22] print('Loading networks from "%s"...' % network_pkl) _G, _D, Gs = pretrained_networks.load_networks(network_pkl) noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]